xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 0c59f51592ef5c014352994369f5216c6376fae1)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op.getOperation()->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
166                                                 unsigned loopDepth) {
167   auto maps = op.indexing_maps();
168   // Iterate over the inputs and outputs in order.
169   // Extract the subranges from the linearized ranges.
170   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
171   for (auto en : llvm::enumerate(ios)) {
172     unsigned idx = en.index();
173     auto map = maps[idx].cast<AffineMapAttr>().getValue();
174     LLVM_DEBUG(llvm::dbgs()
175                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
176     LLVM_DEBUG(llvm::dbgs()
177                << "getShapeDefiningLoopRange map: " << map << "\n");
178     Value shape = en.value();
179     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
180     for (auto en2 : llvm::enumerate(map.getResults())) {
181       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
182       if (!dimExpr)
183         continue;
184       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
185         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
186                                 << loopDepth << "\n");
187         LLVM_DEBUG(llvm::dbgs()
188                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
189         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
190       }
191     }
192   }
193   llvm_unreachable("Expect to be able to extract a shape defining loop range");
194 }
195 
196 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
197 /// provides the loop range information for the fused loops. The rest are
198 /// obtained from the producer itself, since they are not tiled + fused.
199 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
200                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
201 
202   unsigned nPar = producer.getNumParallelLoops();
203   unsigned nRed = producer.getNumReductionLoops();
204   unsigned nWin = producer.getNumWindowLoops();
205   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
206   for (auto fusedLoops : fusedLoopsAndRanges)
207     loopRanges[fusedLoops.first] = fusedLoops.second;
208 
209   // Iterate over all dimensions. For the dimensions not identified by the
210   // producer map for `producerIdx`, we need to explicitly compute the shape
211   // that defines the loop ranges using the `producer`.
212   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
213     if (loopRanges[i].offset)
214       LLVM_DEBUG(llvm::dbgs()
215                  << "existing LoopRange: " << loopRanges[i] << "\n");
216     else {
217       auto shapeDim = getShapeDefiningLoopRange(producer, i);
218       loopRanges[i] = Range{std_constant_index(0),
219                             std_dim(shapeDim.shape, shapeDim.dimension),
220                             std_constant_index(1)};
221       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
222     }
223   }
224 
225   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
226 }
227 
228 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
229 /// expected to be defined by a subview op or a subtensor op.
230 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
231                                       Value shapedOperand, unsigned dim) {
232   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
233   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
234     return subViewOp.getOrCreateRanges(b, loc)[dim];
235   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
236     return subTensorOp.getOrCreateRanges(b, loc)[dim];
237   llvm_unreachable("SubviewOp or SubTensorOp expected");
238 }
239 
240 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
241 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
242 /// is needed just before the `consumer.
243 ///
244 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
245 /// 2 cases:
246 ///   1. Buffer case: `producerIdx` is the index of the buffer in
247 ///      `producer.getOutputBuffers()`.
248 ///   2. Tensor case: `producerIdx` is the index of the tensor in
249 ///      `producer.getResults()`.
250 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
251                      LinalgOp consumer, unsigned consumerIdx) {
252   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
253   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
254                           << ", producer map: " << producerMap << "\n");
255   DenseMap<unsigned, Range> fusedLoopsAndRanges;
256   Location loc = consumer.getLoc();
257   Value shapedOperand = consumer.getShapedOperand(consumerIdx);
258   for (auto en : llvm::enumerate(producerMap.getResults())) {
259     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
260     fusedLoopsAndRanges[posInProducerLoop] =
261         getRangeFromOperandShape(b, loc, shapedOperand, en.index());
262   }
263   return fuse(b, producer, fusedLoopsAndRanges);
264 }
265 
266 // Encode structural fusion safety preconditions.
267 // Some of these will be lifted in the future with better analysis.
268 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
269                                           LinalgOp consumer) {
270   assert(producer.hasBufferSemantics() &&
271          "expected linalg op with buffer semantics");
272   assert(consumer.hasBufferSemantics() &&
273          "expected linalg op with buffer semantics");
274   if (producer.getNumOutputs() != 1) {
275     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
276     return false;
277   }
278   // Only fuse when the producer block dominates.
279   DominanceInfo dom(producer.getOperation());
280   if (!dom.dominates(producer.getOperation()->getBlock(),
281                      consumer.getOperation()->getBlock())) {
282     LLVM_DEBUG(
283         llvm::dbgs()
284         << "\nNot structurally fusable (producer block does not dominate)");
285     return false;
286   }
287   return true;
288 }
289 
290 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
291                                              LinalgOp consumer,
292                                              Value consumedView,
293                                              LinalgOp producer) {
294   assert(producer.hasBufferSemantics() &&
295          "expected linalg op with buffer semantics");
296   assert(consumer.hasBufferSemantics() &&
297          "expected linalg op with buffer semantics");
298   // Make some simple structural checks that alleviate the need for more
299   // complex analyses.
300   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
301     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
302                             << *producer.getOperation());
303     return false;
304   }
305   // Check for any interleaved write to consumedView.
306   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
307     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
308                             << *producer.getOperation());
309     return false;
310   }
311   return true;
312 }
313 
314 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
315                                  LinalgOp consumer, Value consumedView,
316                                  LinalgOp producer) {
317   assert(producer.hasBufferSemantics() &&
318          "expected linalg op with buffer semantics");
319   assert(consumer.hasBufferSemantics() &&
320          "expected linalg op with buffer semantics");
321   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
322     return false;
323   // Check for any fusion-preventing dependence to any shape read/written that
324   // would violate dependences.
325   if (!graph.findCoveringDependences(producer, consumer).empty()) {
326     LLVM_DEBUG(llvm::dbgs()
327                << "\n***Not fusable due to an interleaved dependence:\t"
328                << *producer.getOperation());
329     return false;
330   }
331   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
332     // TODO: add a level of indirection to linalg.generic.
333     if (convOp.padding())
334       return false;
335   }
336   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
337     // TODO: add a level of indirection to linalg.generic.
338     if (convOp.padding())
339       return false;
340   }
341   return true;
342 }
343 
344 static bool isSameSubView(Value a, Value b) {
345   if (a == b)
346     return true;
347   auto sva = a.getDefiningOp<SubViewOp>();
348   auto svb = b.getDefiningOp<SubViewOp>();
349   if (!sva || !svb)
350     return false;
351   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
352     return false;
353   if (sva.getType() != svb.getType())
354     return false;
355   if (sva.getNumOperands() != svb.getNumOperands())
356     return false;
357   if (sva.static_offsets() != svb.static_offsets())
358     return false;
359   if (sva.static_sizes() != svb.static_sizes())
360     return false;
361   if (sva.static_strides() != svb.static_strides())
362     return false;
363   /// Skip the "source" operand.
364   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
365     if (sva.getOperand(idx) != svb.getOperand(idx))
366       return false;
367   return true;
368 }
369 
370 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
371 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
372                     const LinalgDependenceGraph &dependenceGraph) {
373   // Only consider RAW and WAW atm.
374   for (auto depType : {
375            LinalgDependenceGraph::DependenceType::RAW,
376            LinalgDependenceGraph::DependenceType::WAW,
377        }) {
378     for (auto dependence : llvm::make_filter_range(
379              dependenceGraph.getDependencesInto(consumer, depType),
380              [consumerIdx](
381                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
382                return elem.indexingOpView.operandIndex == consumerIdx;
383              })) {
384       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
385 
386       // Check that the dependence is indeed on the input `consumerIdx` view.
387       auto consumedView =
388           consumer.getBuffer(dependence.indexingOpView.operandIndex);
389       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
390         continue;
391 
392       // Consumer consumes this view, `isStructurallyFusableProducer` also
393       // checks whether it is a strict subview of the producer view.
394       auto producedView =
395           producer.getBuffer(dependence.dependentOpView.operandIndex);
396       LLVM_DEBUG(llvm::dbgs()
397                  << "\n"
398                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
399                  << "producer: " << *producer.getOperation()
400                  << " view: " << producedView << " output index: "
401                  << dependence.dependentOpView.operandIndex -
402                         producer.getNumInputs()
403                  << "\n");
404       (void)producedView;
405 
406       // Simple fusability checks.
407       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
408         continue;
409 
410       return dependence;
411     }
412   }
413   return {};
414 }
415 
416 Optional<FusionInfo>
417 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
418                                    unsigned consumerIdx,
419                                    const LinalgDependenceGraph &graph) {
420   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
421       findFusableProducer(consumer, consumerIdx, graph);
422   if (!fusableDependence)
423     return {};
424 
425   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
426   // If producer is already in the same block as consumer, we are done.
427   if (consumer.getOperation()->getBlock() ==
428       producerOp.getOperation()->getBlock())
429     return {};
430 
431   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
432                          producerOp.getNumInputs();
433   Value consumerView = consumer.getShapedOperand(consumerIdx);
434 
435   // Must be a subview or a slice to guarantee there are loops we can fuse
436   // into.
437   auto subView = consumerView.getDefiningOp<SubViewOp>();
438   auto slice = consumerView.getDefiningOp<SliceOp>();
439   if (!subView && !slice) {
440     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
441     return {};
442   }
443 
444   // Fuse `producer` just before `consumer`.
445   OpBuilder::InsertionGuard g(b);
446   b.setInsertionPoint(consumer.getOperation());
447   ScopedContext scope(b, consumer.getLoc());
448   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
449 
450   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
451   return FusionInfo{producerOp, fusedProducer};
452 }
453 
454 /// Walk back use-def chain through scf::For yields.
455 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
456 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
457                                 unsigned &outputIndex) {
458   if (!tensor.getType().isa<RankedTensorType>())
459     return;
460 
461   while (true) {
462     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
463       producer = linalgOp;
464       outputIndex = tensor.cast<OpResult>().getResultNumber();
465       return;
466     }
467     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
468       tensor = subTensorOp.source();
469       continue;
470     }
471     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
472       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
473         tensor = forOp.getResult(blockArg.getArgNumber());
474         continue;
475       }
476     }
477     return;
478   }
479 }
480 
481 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
482                                                         LinalgOp consumer,
483                                                         unsigned consumerIdx) {
484   Value inputTensor = consumer.getInput(consumerIdx);
485   LinalgOp producerOp;
486   unsigned producerIdx;
487   getProducerOfTensor(inputTensor, producerOp, producerIdx);
488 
489   // Must be a subtensor to guarantee there are loops we can fuse into.
490   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
491   if (!subTensor || !producerOp) {
492     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
493     return {};
494   }
495 
496   // If producer is already in the same block as consumer, we are done.
497   if (consumer.getOperation()->getBlock() ==
498       producerOp.getOperation()->getBlock())
499     return {};
500 
501   // Insert fused `producer` just before `consumer`.
502   OpBuilder::InsertionGuard g(b);
503   b.setInsertionPoint(consumer.getOperation());
504   ScopedContext scope(b, consumer.getLoc());
505   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
506   LinalgOp fusedProducer =
507       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
508 
509   // Replace use.
510   // Canonicalizations are not guaranteed to have happened before constructing
511   // `fusedProducer`. In the tensor case this can result in temporary type
512   // mismatches. Insert a `tensor_cast` op to propagate the transformation
513   // invariant that types are compatible.
514   Value def = fusedProducer.getOperation()->getResult(producerIdx);
515   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
516   Type consumerType = use.get().getType();
517   if (consumerType != def.getType())
518     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
519   use.set(def);
520   return FusionInfo{producerOp, fusedProducer};
521 }
522 
523 /// Prune all dimensions that are of reduction iterator type from `map`.
524 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
525                                            AffineMap map) {
526   SmallVector<unsigned, 2> projectedDims;
527   for (auto attr : llvm::enumerate(iteratorTypes)) {
528     if (!isParallelIterator(attr.value()))
529       projectedDims.push_back(attr.index());
530   }
531   return getProjectedMap(map, projectedDims);
532 }
533 
534 /// Returns the mapping from iterations in the consumer that write to the same
535 /// location as the iterations in the producer. To do so use
536 /// - indexing map of the fused view in the consumer : consumerIndexMap
537 /// - indexing map of the fused view in the producer : producerIndexMap
538 ///     consumerLoopToProducerLoop =
539 ///       inverse(producerIndexMap).compose(consumerIndexMap)
540 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
541     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
542   auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
543   AffineMap producerIndexingMap =
544       producer.getIndexingMap(dependence.dependentOpView.operandIndex);
545   auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
546   AffineMap consumerIndexingMap =
547       consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
548 
549   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
550       producer.iterator_types().getValue(), producerIndexingMap);
551   if (!prunedProducerIndexingMap.isPermutation())
552     return None;
553 
554   if (consumerIndexingMap.getNumResults() !=
555       prunedProducerIndexingMap.getNumResults())
556     return None;
557 
558   LLVM_DEBUG({
559     llvm::dbgs() << "\t producerMap : ";
560     producerIndexingMap.print(llvm::dbgs());
561     llvm::dbgs() << "  pruned : ";
562     prunedProducerIndexingMap.print(llvm::dbgs());
563     llvm::dbgs() << "\n";
564     llvm::dbgs() << "\t consumerMap : ";
565     consumerIndexingMap.print(llvm::dbgs());
566     llvm::dbgs() << "\n";
567   });
568 
569   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
570   if (!invProducerIndexMap)
571     return None;
572 
573   return invProducerIndexMap.compose(consumerIndexingMap);
574 }
575 
576 /// Given a projected permutation `map`, returns true if the map changes the
577 /// order in which the fused loop dimension appear.
578 static bool doesTransposeAccess(AffineMap map,
579                                 const std::set<unsigned> &fusableLoops) {
580   Optional<unsigned> lastFusableLoop;
581   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
582          return expr.cast<AffineDimExpr>().getPosition();
583        })) {
584     if (!fusableLoops.count(pos))
585       continue;
586     if (!lastFusableLoop) {
587       lastFusableLoop = pos;
588       continue;
589     }
590     if (pos <= lastFusableLoop.getValue())
591       return true;
592     lastFusableLoop = pos;
593   }
594   return false;
595 }
596 
597 /// Returns the positions of the loop in `op` that can be tiled based on the
598 /// operations that are to be fused with it. For example, in a
599 ///
600 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
601 ///
602 /// if the producer of %a needs to be fused with this op, only the `i` loop of
603 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
604 /// fused, then no loops can be tiled while fusing. The conditions used are:
605 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
606 ///    common outer parallel loops between the op and its producers being fused.
607 /// 2. Of the parallel loops only some can be fused. Only those loops can be
608 ///    fused such where the fusable loops iteration space only touches one tile
609 ///    of the fused operation. This is because the producer (which is writing
610 ///    the fused subview) has update semantics.
611 ///
612 /// Since an inverse computation is needed, we need to consider the projection
613 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
614 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
615 /// parallel loops and appear in the result of the map
616 ///
617 /// Example 1:
618 ///   linalg.fill(%c, %cst)
619 ///   linalg.matmul ins(%a, %b) outs(%c)
620 ///     Number of parallel loops : 2
621 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
622 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
623 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
624 ///     Fused dimensions : i, j
625 ///
626 /// Example 2:
627 ///   linalg.matmul ins(%a, %b) outs(%c)
628 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
629 ///                   iterator_types = ["parallel", "parallel"]}
630 ///     ins(%c) ...
631 ///
632 ///     Number of parallel loops = 2:
633 ///     producerIndexMap (projected to parallel loops) =
634 ///       affine_map<(i, j) -> (i, j)>
635 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
636 ///     Fused dimensions : i, j
637 ///
638 /// Example 3:
639 ///   linalg.copy(%s, %b)
640 ///   linalg.matmul ins(%a, %b) outs(%c)
641 ///
642 ///   Number of parallel loops = 2
643 ///   produceIndexMap : affine_map<(i, j) -> (i, j)>
644 ///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
645 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
646 ///   Fused dimensions : j
647 static std::set<unsigned>
648 collectFusableLoops(ArrayRef<LinalgOp> ops,
649                     const FusableOpDependencesTy &fusableDependences) {
650   assert(!ops.empty());
651   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
652     return linalgOp.iterator_types()
653         .getValue()
654         .take_while([](Attribute attr) -> bool {
655           return attr.cast<StringAttr>().getValue() ==
656                  getParallelIteratorTypeName();
657         })
658         .size();
659   };
660 
661   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
662   for (auto op : ops.drop_back()) {
663     numOuterParallelLoops =
664         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
665   }
666 
667   std::set<unsigned> fusableLoops;
668   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
669   fusableLoops.insert(range.begin(), range.end());
670 
671   for (auto op : reverse(ops)) {
672     for (auto dependence : fusableDependences.lookup(op)) {
673       LLVM_DEBUG({
674         llvm::dbgs() << "\t fusable :";
675         for (unsigned i : fusableLoops)
676           llvm::dbgs() << " " << i;
677         llvm::dbgs() << "\n";
678       });
679 
680       Optional<AffineMap> consumerLoopToProducerLoop =
681           getConsumerLoopToProducerLoopMap(dependence);
682       if (!consumerLoopToProducerLoop) {
683         op.emitRemark("failed to get map from consumer loop to producer loop");
684         return {};
685       }
686       // todo: This condition is only an implementation limitation. When fusing
687       // the operation, if the accesses in the producer/consumer are transposes
688       // of each other, the loop bounds for the tiled producer can be
689       // manipulated accordingly. This requires some additional bookkeeping in
690       // the implementation of tile+fuse that is defered to later.
691       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
692         op.emitRemark("unhandled fusion when fusion requires permutation");
693         return {};
694       }
695 
696       std::set<unsigned> candidates;
697       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
698         unsigned position = expr.cast<AffineDimExpr>().getPosition();
699         if (fusableLoops.count(position))
700           candidates.insert(position);
701       }
702       LLVM_DEBUG({
703         llvm::dbgs() << "\t candidates :";
704         for (unsigned i : candidates)
705           llvm::dbgs() << " " << i;
706         llvm::dbgs() << "\n";
707       });
708       if (candidates.empty())
709         return {};
710       std::swap(candidates, fusableLoops);
711     }
712   }
713 
714   return fusableLoops;
715 }
716 
717 /// Find all dependences that are fusable.
718 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
719     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
720   FusableOpDependencesTy fusableDependences;
721   // TODO: Currently fusion would not be legal if the fusable dependence is to
722   // the same producer but different indexing map in the consumer. Fix this, but
723   // in the meanwhile disallow such a fusion.
724   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
725   for (LinalgOp op : reverse(ops)) {
726     for (auto operandIndex :
727          llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
728       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
729           fusableDependence =
730               findFusableProducer(op, operandIndex, dependenceGraph);
731       if (!fusableDependence)
732         continue;
733       LinalgOp producerOp =
734           cast<LinalgOp>(fusableDependence->dependentOpView.op);
735       // Do not fuse dependences that are to operations not in the same basic
736       // block. This avoid moving fused operations across loops that might
737       // themselves carry dependency making the fusion illegal.
738       if (producerOp.getOperation()->getBlock() !=
739           op.getOperation()->getBlock()) {
740         op.emitRemark("unhandled fusion of ops in different basic blocks");
741         return FusableOpDependencesTy{};
742       }
743       // Make sure that the indexing map of the view used for fusion in the
744       // producer is a projected permutation.
745       unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
746       AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
747       if (!producerMap.isProjectedPermutation()) {
748         op.emitRemark(
749             "unhandled non permutation indexing map for fused view in "
750             "producer for operand at index ")
751             << operandIndex;
752         return FusableOpDependencesTy{};
753       }
754 
755       unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
756       AffineMap consumerMap = op.getIndexingMap(consumerIdx);
757       if (!consumerMap.isProjectedPermutation()) {
758         op.emitRemark(
759             "unhandled case where indexing map for fused view in the consumer "
760             "is "
761             "not a projected permuration while fusing at index ")
762             << operandIndex;
763         return FusableOpDependencesTy{};
764       }
765 
766       // Check if the producer is already a fusion candidate. Cannot fuse this
767       // dependence if it has a different indexing map when used in the
768       // consumer.
769       if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
770           fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
771         op.emitRemark(
772             "unhandled fusion to the same producer but with different "
773             "indexing maps");
774         return FusableOpDependencesTy{};
775       }
776       fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
777 
778       fusableDependences[producerOp.getOperation()].push_back(
779           *fusableDependence);
780     }
781   }
782   return fusableDependences;
783 }
784 
785 static bool isZero(Value v) {
786   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
787     return cst.getValue() == 0;
788   return false;
789 }
790 
791 /// Tile the fused loops in the root operation, by setting the tile sizes for
792 /// all other loops to zero (those will be tiled later).
793 static Optional<TiledLinalgOp> tileRootOperation(
794     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
795     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
796   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
797   auto zero = std_constant_index(0);
798   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
799     if (!fusedLoops.count(i))
800       tileSizes[i] = zero;
801   LinalgTilingOptions tileFusedLoopsOptions = options;
802   tileFusedLoopsOptions.setTileSizes(tileSizes);
803   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
804 }
805 
806 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
807 /// to be a tiled operation such that it is valid to fuse all operations in
808 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
809 /// `tiledOp`.
810 static SmallVector<LinalgOp, 1>
811 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
812                ArrayRef<LinalgOp> fusionCandidates,
813                const FusableOpDependencesTy &fusableDependences,
814                const std::set<unsigned> &fusedLoops) {
815   OpBuilder::InsertionGuard guard(builder);
816   builder.setInsertionPoint(tiledOp);
817   DenseMap<unsigned, Range> fusedLoopsAndRanges;
818   for (unsigned loop : fusedLoops) {
819     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop);
820     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
821         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
822   }
823 
824   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
825   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
826     LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
827     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
828     builder.setInsertionPoint(fusedOp);
829   }
830   return fusedOps;
831 }
832 
833 template <typename LoopType>
834 static Optional<TiledAndFusedLinalgOps>
835 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
836                          const LinalgDependenceGraph &dependenceGraph,
837                          const LinalgTilingOptions &tilingOptions) {
838   if (ops.empty())
839     return llvm::None;
840   LinalgOp rootOp = ops.back();
841   for (auto op : enumerate(ops)) {
842     // TODO: Nothing in the fusion of sequence of ops is specific to
843     // buffers. This check can be removed after it is tested on tensors.
844     LinalgOp linalgOp = op.value();
845     if (!linalgOp.hasBufferSemantics()) {
846       linalgOp.emitError("tile and fuse only tested for buffer operation");
847       return llvm::None;
848     }
849   }
850   // TODO: Support interchange with tile + fuse. This might actually help do
851   // better fusion.
852   if (!tilingOptions.interchangeVector.empty()) {
853     rootOp.emitError("unable to handle tile and fuse with interchange");
854     return llvm::None;
855   }
856 
857   OpBuilder::InsertionGuard guard(builder);
858   builder.setInsertionPoint(rootOp);
859   ScopedContext scope(builder, rootOp.getLoc());
860 
861   // Find all the producers.
862   FusableOpDependencesTy fusableDependences =
863       findAllFusableDependences(ops, dependenceGraph);
864   if (fusableDependences.empty())
865     return llvm::None;
866 
867   TiledAndFusedLinalgOps ret;
868   // Find the loops that can be tiled and fused.
869   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
870 
871   // If there are no fusable dependences or there are no tile+fusable loops,
872   // just return.
873   if (ret.fusedLoopDims.empty()) {
874     return llvm::None;
875   }
876 
877   // Tile the fused loops in the last operation in the list.
878   SmallVector<Value, 4> tileSizeVector =
879       tilingOptions.tileSizeComputationFunction(builder, rootOp);
880   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
881       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
882   if (!tiledRootOp) {
883     rootOp.emitError("failed to tile the fused loops");
884     return llvm::None;
885   }
886   ret.op = tiledRootOp->op;
887   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
888 
889   // Fuse the other operations into the fused inter-tile loops produced above.
890   ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
891                                       fusableDependences, ret.fusedLoopDims);
892   return ret;
893 }
894 
895 Optional<TiledAndFusedLinalgOps>
896 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
897                                    const LinalgDependenceGraph &dependenceGraph,
898                                    const LinalgTilingOptions &tilingOptions) {
899   switch (tilingOptions.loopType) {
900   case LinalgTilingLoopType::Loops:
901     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
902                                                 tilingOptions);
903   case LinalgTilingLoopType::ParallelLoops:
904     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
905         builder, ops, dependenceGraph, tilingOptions);
906   default:;
907   }
908   return llvm::None;
909 }
910