xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision e65a5e5b00a37700a79e0a9f2fb1c1e60a2584bf)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op.getOperation()->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension
166 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
167                           bool fromSubViewOpOnly = false) {
168   auto maps = op.indexing_maps();
169   // Iterate over the inputs and outputs in order.
170   // Extract the subranges from the linearized ranges.
171   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
172   for (auto en : llvm::enumerate(ios)) {
173     // The method `getRangeFromOperandShape` requires using SubViewOp or
174     // SubTensorOps. If the value isnt defined from there continue.
175     // todo: The method should be adapted to get the values from
176     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
177     // currently returns a `linalg.range`. The fix here is to move this op to
178     // `std` dialect and add the method to `ViewInterface`.
179     if (fromSubViewOpOnly &&
180         !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
181       continue;
182 
183     unsigned idx = en.index();
184     auto map = maps[idx].cast<AffineMapAttr>().getValue();
185     LLVM_DEBUG(llvm::dbgs()
186                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
187     LLVM_DEBUG(llvm::dbgs()
188                << "getShapeDefiningLoopRange map: " << map << "\n");
189     Value shape = en.value();
190     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
191     for (auto en2 : llvm::enumerate(map.getResults())) {
192       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
193       if (!dimExpr)
194         continue;
195       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
196         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
197                                 << loopDepth << "\n");
198         LLVM_DEBUG(llvm::dbgs()
199                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
200         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
201       }
202     }
203   }
204   llvm_unreachable("Expect to be able to extract a shape defining loop range");
205 }
206 
207 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
208 /// provides the loop range information for the fused loops. The rest are
209 /// obtained from the producer itself, since they are not tiled + fused.
210 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
211                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
212 
213   unsigned nPar = producer.getNumParallelLoops();
214   unsigned nRed = producer.getNumReductionLoops();
215   unsigned nWin = producer.getNumWindowLoops();
216   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
217   for (auto fusedLoops : fusedLoopsAndRanges)
218     loopRanges[fusedLoops.first] = fusedLoops.second;
219 
220   // Iterate over all dimensions. For the dimensions not identified by the
221   // producer map for `producerIdx`, we need to explicitly compute the shape
222   // that defines the loop ranges using the `producer`.
223   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
224     if (loopRanges[i].offset)
225       LLVM_DEBUG(llvm::dbgs()
226                  << "existing LoopRange: " << loopRanges[i] << "\n");
227     else {
228       auto shapeDim = getShapeDefiningLoopRange(producer, i);
229       loopRanges[i] = Range{std_constant_index(0),
230                             std_dim(shapeDim.shape, shapeDim.dimension),
231                             std_constant_index(1)};
232       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
233     }
234   }
235 
236   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
237 }
238 
239 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
240 /// expected to be defined by a subview op or a subtensor op.
241 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
242                                       Value shapedOperand, unsigned dim) {
243   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
244   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
245     return subViewOp.getOrCreateRanges(b, loc)[dim];
246   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
247     return subTensorOp.getOrCreateRanges(b, loc)[dim];
248   llvm_unreachable("SubviewOp or SubTensorOp expected");
249 }
250 
251 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
252 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
253 /// is needed just before the `consumer.
254 ///
255 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
256 /// 2 cases:
257 ///   1. Buffer case: `producerIdx` is the index of the buffer in
258 ///      `producer.getOutputBuffers()`.
259 ///   2. Tensor case: `producerIdx` is the index of the tensor in
260 ///      `producer.getResults()`.
261 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
262                      LinalgOp consumer, unsigned consumerIdx) {
263   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
264   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
265                           << ", producer map: " << producerMap << "\n");
266   DenseMap<unsigned, Range> fusedLoopsAndRanges;
267   Location loc = consumer.getLoc();
268   Value shapedOperand = consumer.getShapedOperand(consumerIdx);
269   for (auto en : llvm::enumerate(producerMap.getResults())) {
270     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
271     fusedLoopsAndRanges[posInProducerLoop] =
272         getRangeFromOperandShape(b, loc, shapedOperand, en.index());
273   }
274   return fuse(b, producer, fusedLoopsAndRanges);
275 }
276 
277 // Encode structural fusion safety preconditions.
278 // Some of these will be lifted in the future with better analysis.
279 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
280                                           LinalgOp consumer) {
281   assert(producer.hasBufferSemantics() &&
282          "expected linalg op with buffer semantics");
283   assert(consumer.hasBufferSemantics() &&
284          "expected linalg op with buffer semantics");
285   if (producer.getNumOutputs() != 1) {
286     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
287     return false;
288   }
289   // Only fuse when the producer block dominates.
290   DominanceInfo dom(producer.getOperation());
291   if (!dom.dominates(producer.getOperation()->getBlock(),
292                      consumer.getOperation()->getBlock())) {
293     LLVM_DEBUG(
294         llvm::dbgs()
295         << "\nNot structurally fusable (producer block does not dominate)");
296     return false;
297   }
298   return true;
299 }
300 
301 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
302                                              LinalgOp consumer,
303                                              Value consumedView,
304                                              LinalgOp producer) {
305   assert(producer.hasBufferSemantics() &&
306          "expected linalg op with buffer semantics");
307   assert(consumer.hasBufferSemantics() &&
308          "expected linalg op with buffer semantics");
309   // Make some simple structural checks that alleviate the need for more
310   // complex analyses.
311   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
312     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
313                             << *producer.getOperation());
314     return false;
315   }
316   // Check for any interleaved write to consumedView.
317   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
318     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
319                             << *producer.getOperation());
320     return false;
321   }
322   return true;
323 }
324 
325 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
326                                  LinalgOp consumer, Value consumedView,
327                                  LinalgOp producer) {
328   assert(producer.hasBufferSemantics() &&
329          "expected linalg op with buffer semantics");
330   assert(consumer.hasBufferSemantics() &&
331          "expected linalg op with buffer semantics");
332   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
333     return false;
334   // Check for any fusion-preventing dependence to any shape read/written that
335   // would violate dependences.
336   if (!graph.findCoveringDependences(producer, consumer).empty()) {
337     LLVM_DEBUG(llvm::dbgs()
338                << "\n***Not fusable due to an interleaved dependence:\t"
339                << *producer.getOperation());
340     return false;
341   }
342   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
343     // TODO: add a level of indirection to linalg.generic.
344     if (convOp.padding())
345       return false;
346   }
347   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
348     // TODO: add a level of indirection to linalg.generic.
349     if (convOp.padding())
350       return false;
351   }
352   return true;
353 }
354 
355 static bool isSameSubView(Value a, Value b) {
356   if (a == b)
357     return true;
358   auto sva = a.getDefiningOp<SubViewOp>();
359   auto svb = b.getDefiningOp<SubViewOp>();
360   if (!sva || !svb)
361     return false;
362   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
363     return false;
364   if (sva.getType() != svb.getType())
365     return false;
366   if (sva.getNumOperands() != svb.getNumOperands())
367     return false;
368   if (sva.static_offsets() != svb.static_offsets())
369     return false;
370   if (sva.static_sizes() != svb.static_sizes())
371     return false;
372   if (sva.static_strides() != svb.static_strides())
373     return false;
374   /// Skip the "source" operand.
375   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
376     if (sva.getOperand(idx) != svb.getOperand(idx))
377       return false;
378   return true;
379 }
380 
381 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
382 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
383                     const LinalgDependenceGraph &dependenceGraph) {
384   // Only consider RAW and WAW atm.
385   for (auto depType : {
386            LinalgDependenceGraph::DependenceType::RAW,
387            LinalgDependenceGraph::DependenceType::WAW,
388        }) {
389     for (auto dependence : llvm::make_filter_range(
390              dependenceGraph.getDependencesInto(consumer, depType),
391              [consumerIdx](
392                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
393                return elem.indexingOpView.operandIndex == consumerIdx;
394              })) {
395       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
396 
397       // Check that the dependence is indeed on the input `consumerIdx` view.
398       auto consumedView =
399           consumer.getBuffer(dependence.indexingOpView.operandIndex);
400       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
401         continue;
402 
403       // Consumer consumes this view, `isStructurallyFusableProducer` also
404       // checks whether it is a strict subview of the producer view.
405       auto producedView =
406           producer.getBuffer(dependence.dependentOpView.operandIndex);
407       LLVM_DEBUG(llvm::dbgs()
408                  << "\n"
409                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
410                  << "producer: " << *producer.getOperation()
411                  << " view: " << producedView << " output index: "
412                  << dependence.dependentOpView.operandIndex -
413                         producer.getNumInputs()
414                  << "\n");
415       (void)producedView;
416 
417       // Simple fusability checks.
418       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
419         continue;
420 
421       return dependence;
422     }
423   }
424   return {};
425 }
426 
427 Optional<FusionInfo>
428 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
429                                    unsigned consumerIdx,
430                                    const LinalgDependenceGraph &graph) {
431   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
432       findFusableProducer(consumer, consumerIdx, graph);
433   if (!fusableDependence)
434     return {};
435 
436   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
437   // If producer is already in the same block as consumer, we are done.
438   if (consumer.getOperation()->getBlock() ==
439       producerOp.getOperation()->getBlock())
440     return {};
441 
442   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
443                          producerOp.getNumInputs();
444   Value consumerView = consumer.getShapedOperand(consumerIdx);
445 
446   // Must be a subview or a slice to guarantee there are loops we can fuse
447   // into.
448   auto subView = consumerView.getDefiningOp<SubViewOp>();
449   auto slice = consumerView.getDefiningOp<SliceOp>();
450   if (!subView && !slice) {
451     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
452     return {};
453   }
454 
455   // Fuse `producer` just before `consumer`.
456   OpBuilder::InsertionGuard g(b);
457   b.setInsertionPoint(consumer.getOperation());
458   ScopedContext scope(b, consumer.getLoc());
459   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
460 
461   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
462   return FusionInfo{producerOp, fusedProducer};
463 }
464 
465 /// Walk back use-def chain through scf::For yields.
466 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
467 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
468                                 unsigned &outputIndex) {
469   if (!tensor.getType().isa<RankedTensorType>())
470     return;
471 
472   while (true) {
473     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
474       producer = linalgOp;
475       outputIndex = tensor.cast<OpResult>().getResultNumber();
476       return;
477     }
478     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
479       tensor = subTensorOp.source();
480       continue;
481     }
482     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
483       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
484         tensor = forOp.getResult(blockArg.getArgNumber());
485         continue;
486       }
487     }
488     return;
489   }
490 }
491 
492 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
493                                                         LinalgOp consumer,
494                                                         unsigned consumerIdx) {
495   Value inputTensor = consumer.getInput(consumerIdx);
496   LinalgOp producerOp;
497   unsigned producerIdx;
498   getProducerOfTensor(inputTensor, producerOp, producerIdx);
499 
500   // Must be a subtensor to guarantee there are loops we can fuse into.
501   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
502   if (!subTensor || !producerOp) {
503     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
504     return {};
505   }
506 
507   // If producer is already in the same block as consumer, we are done.
508   if (consumer.getOperation()->getBlock() ==
509       producerOp.getOperation()->getBlock())
510     return {};
511 
512   // Insert fused `producer` just before `consumer`.
513   OpBuilder::InsertionGuard g(b);
514   b.setInsertionPoint(consumer.getOperation());
515   ScopedContext scope(b, consumer.getLoc());
516   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
517   LinalgOp fusedProducer =
518       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
519 
520   // Replace use.
521   // Canonicalizations are not guaranteed to have happened before constructing
522   // `fusedProducer`. In the tensor case this can result in temporary type
523   // mismatches. Insert a `tensor_cast` op to propagate the transformation
524   // invariant that types are compatible.
525   Value def = fusedProducer.getOperation()->getResult(producerIdx);
526   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
527   Type consumerType = use.get().getType();
528   if (consumerType != def.getType())
529     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
530   use.set(def);
531   return FusionInfo{producerOp, fusedProducer};
532 }
533 
534 /// Prune all dimensions that are of reduction iterator type from `map`.
535 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
536                                            AffineMap map) {
537   SmallVector<unsigned, 2> projectedDims;
538   for (auto attr : llvm::enumerate(iteratorTypes)) {
539     if (!isParallelIterator(attr.value()))
540       projectedDims.push_back(attr.index());
541   }
542   return getProjectedMap(map, projectedDims);
543 }
544 
545 using FusableOpDependencesTy = llvm::MapVector<
546     Operation *,
547     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
548 
549 /// Returns the mapping from iterations in the consumer that write to the same
550 /// location as the iterations in the producer. To do so use
551 /// - indexing map of the fused view in the consumer : consumerIndexMap
552 /// - indexing map of the fused view in the producer : producerIndexMap
553 ///     consumerLoopToProducerLoop =
554 ///       inverse(producerIndexMap).compose(consumerIndexMap)
555 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
556     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
557   auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
558   AffineMap producerIndexingMap =
559       producer.getIndexingMap(dependence.dependentOpView.operandIndex);
560   auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
561   AffineMap consumerIndexingMap =
562       consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
563 
564   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
565       producer.iterator_types().getValue(), producerIndexingMap);
566   if (!prunedProducerIndexingMap.isPermutation())
567     return None;
568 
569   if (consumerIndexingMap.getNumResults() !=
570       prunedProducerIndexingMap.getNumResults())
571     return None;
572 
573   LLVM_DEBUG({
574     llvm::dbgs() << "\t producerMap : ";
575     producerIndexingMap.print(llvm::dbgs());
576     llvm::dbgs() << "  pruned : ";
577     prunedProducerIndexingMap.print(llvm::dbgs());
578     llvm::dbgs() << "\n";
579     llvm::dbgs() << "\t consumerMap : ";
580     consumerIndexingMap.print(llvm::dbgs());
581     llvm::dbgs() << "\n";
582   });
583 
584   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
585   if (!invProducerIndexMap)
586     return None;
587 
588   return invProducerIndexMap.compose(consumerIndexingMap);
589 }
590 
591 /// Given a projected permutation `map`, returns true if the map changes the
592 /// order in which the fused loop dimension appear.
593 static bool doesTransposeAccess(AffineMap map,
594                                 const std::set<unsigned> &fusableLoops) {
595   Optional<unsigned> lastFusableLoop;
596   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
597          return expr.cast<AffineDimExpr>().getPosition();
598        })) {
599     if (!fusableLoops.count(pos))
600       continue;
601     if (!lastFusableLoop) {
602       lastFusableLoop = pos;
603       continue;
604     }
605     if (pos <= lastFusableLoop.getValue())
606       return true;
607     lastFusableLoop = pos;
608   }
609   return false;
610 }
611 
612 /// Returns the positions of the loop in `op` that can be tiled based on the
613 /// operations that are to be fused with it. For example, in a
614 ///
615 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
616 ///
617 /// if the producer of %a needs to be fused with this op, only the `i` loop of
618 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
619 /// fused, then no loops can be tiled while fusing. The conditions used are:
620 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
621 ///    common outer parallel loops between the op and its producers being fused.
622 /// 2. Of the parallel loops only some can be fused. Only those loops can be
623 ///    fused such where the fusable loops iteration space only touches one tile
624 ///    of the fused operation. This is because the producer (which is writing
625 ///    the fused subview) has update semantics.
626 ///
627 /// Since an inverse computation is needed, we need to consider the projection
628 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
629 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
630 /// parallel loops and appear in the result of the map
631 ///
632 /// Example 1:
633 ///   linalg.fill(%c, %cst)
634 ///   linalg.matmul ins(%a, %b) outs(%c)
635 ///     Number of parallel loops : 2
636 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
637 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
638 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
639 ///     Fused dimensions : i, j
640 ///
641 /// Example 2:
642 ///   linalg.matmul ins(%a, %b) outs(%c)
643 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
644 ///                   iterator_types = ["parallel", "parallel"]}
645 ///     ins(%c) ...
646 ///
647 ///     Number of parallel loops = 2:
648 ///     producerIndexMap (projected to parallel loops) =
649 ///       affine_map<(i, j) -> (i, j)>
650 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
651 ///     Fused dimensions : i, j
652 ///
653 /// Example 3:
654 ///   linalg.copy(%s, %b)
655 ///   linalg.matmul ins(%a, %b) outs(%c)
656 ///
657 ///   Number of parallel loops = 2
658 ///   produceIndexMap : affine_map<(i, j) -> (i, j)>
659 ///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
660 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
661 ///   Fused dimensions : j
662 static std::set<unsigned>
663 collectFusableLoops(ArrayRef<LinalgOp> ops,
664                     const FusableOpDependencesTy &fusableDependences) {
665   assert(!ops.empty());
666   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
667     return linalgOp.iterator_types()
668         .getValue()
669         .take_while([](Attribute attr) -> bool {
670           return attr.cast<StringAttr>().getValue() ==
671                  getParallelIteratorTypeName();
672         })
673         .size();
674   };
675 
676   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
677   for (auto op : ops.drop_back()) {
678     numOuterParallelLoops =
679         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
680   }
681 
682   std::set<unsigned> fusableLoops;
683   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
684   fusableLoops.insert(range.begin(), range.end());
685 
686   for (auto op : reverse(ops)) {
687     for (auto dependence : fusableDependences.lookup(op)) {
688       LLVM_DEBUG({
689         llvm::dbgs() << "\t fusable :";
690         for (unsigned i : fusableLoops)
691           llvm::dbgs() << " " << i;
692         llvm::dbgs() << "\n";
693       });
694 
695       Optional<AffineMap> consumerLoopToProducerLoop =
696           getConsumerLoopToProducerLoopMap(dependence);
697       if (!consumerLoopToProducerLoop) {
698         op.emitRemark("failed to get map from consumer loop to producer loop");
699         return {};
700       }
701       // todo: This condition is only an implementation limitation. When fusing
702       // the operation, if the accesses in the producer/consumer are transposes
703       // of each other, the loop bounds for the tiled producer can be
704       // manipulated accordingly. This requires some additional bookkeeping in
705       // the implementation of tile+fuse that is defered to later.
706       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
707         op.emitRemark("unhandled fusion when fusion requires permutation");
708         return {};
709       }
710 
711       std::set<unsigned> candidates;
712       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
713         unsigned position = expr.cast<AffineDimExpr>().getPosition();
714         if (fusableLoops.count(position))
715           candidates.insert(position);
716       }
717       LLVM_DEBUG({
718         llvm::dbgs() << "\t candidates :";
719         for (unsigned i : candidates)
720           llvm::dbgs() << " " << i;
721         llvm::dbgs() << "\n";
722       });
723       if (candidates.empty())
724         return {};
725       std::swap(candidates, fusableLoops);
726     }
727   }
728 
729   return fusableLoops;
730 }
731 
732 /// Find all dependences that are to be fusable.
733 static FusableOpDependencesTy
734 findAllFusableDependences(ArrayRef<LinalgOp> ops,
735                           const LinalgDependenceGraph &dependenceGraph) {
736   FusableOpDependencesTy fusableDependences;
737   // TODO: Currently fusion would not be legal if the fusable dependence is to
738   // the same producer but different indexing map in the consumer. Fix this, but
739   // in the meanwhile disallow such a fusion.
740   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
741   for (LinalgOp op : reverse(ops)) {
742     for (auto operandIndex :
743          llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
744       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
745           fusableDependence =
746               findFusableProducer(op, operandIndex, dependenceGraph);
747       if (!fusableDependence)
748         continue;
749       LinalgOp producerOp =
750           cast<LinalgOp>(fusableDependence->dependentOpView.op);
751       // Do not fuse dependences that are to operations not in the same basic
752       // block. This avoid moving fused operations across loops that might
753       // themselves carry dependency making the fusion illegal.
754       if (producerOp.getOperation()->getBlock() !=
755           op.getOperation()->getBlock()) {
756         op.emitRemark("unhandled fusion of ops in different basic blocks");
757         return FusableOpDependencesTy{};
758       }
759       // Make sure that the indexing map of the view used for fusion in the
760       // producer is a projected permutation.
761       unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
762       AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
763       if (!producerMap.isProjectedPermutation()) {
764         op.emitRemark(
765             "unhandled non permutation indexing map for fused view in "
766             "producer for operand at index ")
767             << operandIndex;
768         return FusableOpDependencesTy{};
769       }
770 
771       unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
772       AffineMap consumerMap = op.getIndexingMap(consumerIdx);
773       if (!consumerMap.isProjectedPermutation()) {
774         op.emitRemark(
775             "unhandled case where indexing map for fused view in the consumer "
776             "is "
777             "not a projected permuration while fusing at index ")
778             << operandIndex;
779         return FusableOpDependencesTy{};
780       }
781 
782       // Check if the producer is already a fusion candidate. Cannot fuse this
783       // dependence if it has a different indexing map when used in the
784       // consumer.
785       if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
786           fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
787         op.emitRemark(
788             "unhandled fusion to the same producer but with different "
789             "indexing maps");
790         return FusableOpDependencesTy{};
791       }
792       fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
793 
794       fusableDependences[producerOp.getOperation()].push_back(
795           *fusableDependence);
796     }
797   }
798   return fusableDependences;
799 }
800 
801 static bool isZero(Value v) {
802   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
803     return cst.getValue() == 0;
804   return false;
805 }
806 
807 /// Tile the fused loops in the root operation, by setting the tile sizes for
808 /// all other loops to zero (those will be tiled later).
809 static Optional<TiledLinalgOp> tileRootOperation(
810     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
811     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
812   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
813   auto zero = std_constant_index(0);
814   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
815     if (!fusedLoops.count(i))
816       tileSizes[i] = zero;
817   LinalgTilingOptions tileFusedLoopsOptions = options;
818   tileFusedLoopsOptions.setTileSizes(tileSizes);
819   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
820 }
821 
822 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
823 /// to be a tiled operation such that it is valid to fuse all operations in
824 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
825 /// `tiledOp`.
826 static SmallVector<LinalgOp, 1>
827 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
828                ArrayRef<LinalgOp> fusionCandidates,
829                const FusableOpDependencesTy &fusableDependences,
830                const std::set<unsigned> &fusedLoops) {
831   OpBuilder::InsertionGuard guard(builder);
832   builder.setInsertionPoint(tiledOp);
833   DenseMap<unsigned, Range> fusedLoopsAndRanges;
834   for (unsigned loop : fusedLoops) {
835     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
836     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
837         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
838   }
839   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
840   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
841     LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
842     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
843     builder.setInsertionPoint(fusedOp);
844   }
845   return fusedOps;
846 }
847 
848 template <typename LoopType>
849 static Optional<TiledAndFusedLinalgOps>
850 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
851                          const LinalgDependenceGraph &dependenceGraph,
852                          const LinalgTilingOptions &tilingOptions) {
853   if (ops.empty())
854     return llvm::None;
855   LinalgOp rootOp = ops.back();
856   for (auto op : enumerate(ops)) {
857     // TODO: Nothing in the fusion of sequence of ops is specific to
858     // buffers. This check can be removed after it is tested on tensors.
859     LinalgOp linalgOp = op.value();
860     if (!linalgOp.hasBufferSemantics()) {
861       linalgOp.emitError("tile and fuse only tested for buffer operation");
862       return llvm::None;
863     }
864   }
865   // TODO: Support interchange with tile + fuse. This might actually help do
866   // better fusion.
867   if (!tilingOptions.interchangeVector.empty()) {
868     rootOp.emitError("unable to handle tile and fuse with interchange");
869     return llvm::None;
870   }
871 
872   OpBuilder::InsertionGuard guard(builder);
873   builder.setInsertionPoint(rootOp);
874   ScopedContext scope(builder, rootOp.getLoc());
875 
876   // Find all the producers.
877   FusableOpDependencesTy fusableDependences =
878       findAllFusableDependences(ops, dependenceGraph);
879   if (fusableDependences.empty())
880     return llvm::None;
881 
882   TiledAndFusedLinalgOps ret;
883   // Find the loops that can be tiled and fused.
884   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
885 
886   // If there are no fusable dependences or there are no tile+fusable loops,
887   // just return.
888   if (ret.fusedLoopDims.empty()) {
889     return llvm::None;
890   }
891 
892   // Tile the fused loops in the last operation in the list.
893   SmallVector<Value, 4> tileSizeVector =
894       tilingOptions.tileSizeComputationFunction(builder, rootOp);
895   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
896       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
897   if (!tiledRootOp) {
898     rootOp.emitError("failed to tile the fused loops");
899     return llvm::None;
900   }
901   ret.op = tiledRootOp->op;
902   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
903 
904   // Fuse the other operations into the fused inter-tile loops produced above.
905   ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
906                                       fusableDependences, ret.fusedLoopDims);
907   return ret;
908 }
909 
910 Optional<TiledAndFusedLinalgOps>
911 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
912                                    const LinalgDependenceGraph &dependenceGraph,
913                                    const LinalgTilingOptions &tilingOptions) {
914   switch (tilingOptions.loopType) {
915   case LinalgTilingLoopType::Loops:
916     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
917                                                 tilingOptions);
918   case LinalgTilingLoopType::ParallelLoops:
919     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
920         builder, ops, dependenceGraph, tilingOptions);
921   default:;
922   }
923   return llvm::None;
924 }
925