xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 5ca20851e44c906a446e1860f01ee5b0f6f795a6)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op.getOperation()->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
166                                                 unsigned loopDepth) {
167   auto maps = op.indexing_maps();
168   // Iterate over the inputs and outputs in order.
169   // Extract the subranges from the linearized ranges.
170   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
171   for (auto en : llvm::enumerate(ios)) {
172     unsigned idx = en.index();
173     auto map = maps[idx].cast<AffineMapAttr>().getValue();
174     LLVM_DEBUG(llvm::dbgs()
175                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
176     LLVM_DEBUG(llvm::dbgs()
177                << "getShapeDefiningLoopRange map: " << map << "\n");
178     Value shape = en.value();
179     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
180     for (auto en2 : llvm::enumerate(map.getResults())) {
181       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
182         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
183                                 << loopDepth << "\n");
184         LLVM_DEBUG(llvm::dbgs()
185                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
186         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
187       }
188     }
189   }
190   llvm_unreachable("Expect to be able to extract a shape defining loop range");
191 }
192 
193 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
194 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
195 /// is needed just before the `consumer.
196 ///
197 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
198 /// 2 cases:
199 ///   1. Buffer case: `producerIdx` is the index of the buffer in
200 ///      `producer.getOutputBuffers()`.
201 ///   2. Tensor case: `producerIdx` is the index of the tensor in
202 ///      `producer.getResults()`.
203 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
204                      LinalgOp consumer, unsigned consumerIdx) {
205   Operation *shapeProducingOp =
206       consumer.getShapedOperand(consumerIdx).getDefiningOp();
207   assert((isa<SubViewOp>(shapeProducingOp) ||
208           isa<SubTensorOp>(shapeProducingOp)) &&
209          "SubviewOp or SubTensorOp expected");
210 
211   // loopToOperandRangesMaps are permutations-only by construction:
212   //   we can always identify a data dimension with a (at least one) loop
213   //   dimension.
214   // TODO: extend this with range inference.
215   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
216   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
217                           << ", producer map: " << producerMap << "\n");
218 
219   unsigned nPar = producer.getNumParallelLoops();
220   unsigned nRed = producer.getNumReductionLoops();
221   unsigned nWin = producer.getNumWindowLoops();
222   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
223 
224   // Iterate over dimensions identified by the producer map for `producerIdx`.
225   // This defines a subset of the loop ranges that we need to complete later.
226   auto loc = consumer.getLoc();
227   for (auto en : llvm::enumerate(producerMap.getResults())) {
228     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
229     loopRanges[posInProducerLoop] =
230         isa<SubViewOp>(shapeProducingOp)
231             ? cast<SubViewOp>(shapeProducingOp)
232                   .getOrCreateRanges(b, loc)[en.index()]
233             : cast<SubTensorOp>(shapeProducingOp)
234                   .getOrCreateRanges(b, loc)[en.index()];
235   }
236 
237   // Iterate over all dimensions. For the dimensions not identified by the
238   // producer map for `producerIdx`, we need to explicitly compute the shape
239   // that defines the loop ranges using the `producer`.
240   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
241     if (loopRanges[i].offset)
242       LLVM_DEBUG(llvm::dbgs()
243                  << "existing LoopRange: " << loopRanges[i] << "\n");
244     else {
245       auto shapeDim = getShapeDefiningLoopRange(producer, i);
246       loopRanges[i] = Range{std_constant_index(0),
247                             std_dim(shapeDim.shape, shapeDim.dimension),
248                             std_constant_index(1)};
249       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
250     }
251   }
252 
253   return cloneWithLoopRanges(b, loc, producer, loopRanges);
254 }
255 
256 // Encode structural fusion safety preconditions.
257 // Some of these will be lifted in the future with better analysis.
258 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
259                                           LinalgOp consumer) {
260   assert(producer.hasBufferSemantics() &&
261          "expected linalg op with buffer semantics");
262   assert(consumer.hasBufferSemantics() &&
263          "expected linalg op with buffer semantics");
264   if (producer.getNumOutputs() != 1) {
265     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
266     return false;
267   }
268   // Only fuse when the producer block dominates.
269   DominanceInfo dom(producer.getOperation());
270   if (!dom.dominates(producer.getOperation()->getBlock(),
271                      consumer.getOperation()->getBlock())) {
272     LLVM_DEBUG(
273         llvm::dbgs()
274         << "\nNot structurally fusable (producer block does not dominate)");
275     return false;
276   }
277   return true;
278 }
279 
280 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
281                                              LinalgOp consumer,
282                                              Value consumedView,
283                                              LinalgOp producer) {
284   assert(producer.hasBufferSemantics() &&
285          "expected linalg op with buffer semantics");
286   assert(consumer.hasBufferSemantics() &&
287          "expected linalg op with buffer semantics");
288   // Make some simple structural checks that alleviate the need for more
289   // complex analyses.
290   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
291     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
292                             << *producer.getOperation());
293     return false;
294   }
295   // Check for any interleaved write to consumedView.
296   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
297     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
298                             << *producer.getOperation());
299     return false;
300   }
301   return true;
302 }
303 
304 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
305                                  LinalgOp consumer, Value consumedView,
306                                  LinalgOp producer) {
307   assert(producer.hasBufferSemantics() &&
308          "expected linalg op with buffer semantics");
309   assert(consumer.hasBufferSemantics() &&
310          "expected linalg op with buffer semantics");
311   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
312     return false;
313   // Check for any fusion-preventing dependence to any shape read/written that
314   // would violate dependences.
315   if (!graph.findCoveringDependences(producer, consumer).empty()) {
316     LLVM_DEBUG(llvm::dbgs()
317                << "\n***Not fusable due to an interleaved dependence:\t"
318                << *producer.getOperation());
319     return false;
320   }
321   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
322     // TODO: add a level of indirection to linalg.generic.
323     if (convOp.padding())
324       return false;
325   }
326   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
327     // TODO: add a level of indirection to linalg.generic.
328     if (convOp.padding())
329       return false;
330   }
331   return true;
332 }
333 
334 static bool isSameSubView(Value a, Value b) {
335   if (a == b)
336     return true;
337   auto sva = a.getDefiningOp<SubViewOp>();
338   auto svb = b.getDefiningOp<SubViewOp>();
339   if (!sva || !svb)
340     return false;
341   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
342     return false;
343   if (sva.getType() != svb.getType())
344     return false;
345   if (sva.getNumOperands() != svb.getNumOperands())
346     return false;
347   if (sva.static_offsets() != svb.static_offsets())
348     return false;
349   if (sva.static_sizes() != svb.static_sizes())
350     return false;
351   if (sva.static_strides() != svb.static_strides())
352     return false;
353   /// Skip the "source" operand.
354   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
355     if (sva.getOperand(idx) != svb.getOperand(idx))
356       return false;
357   return true;
358 }
359 
360 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
361 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
362                     const LinalgDependenceGraph &dependenceGraph) {
363   // Only consider RAW and WAW atm.
364   for (auto depType : {
365            LinalgDependenceGraph::DependenceType::RAW,
366            LinalgDependenceGraph::DependenceType::WAW,
367        }) {
368     for (auto dependence : llvm::make_filter_range(
369              dependenceGraph.getDependencesInto(consumer, depType),
370              [consumerIdx](
371                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
372                return elem.indexingOpView.operandIndex == consumerIdx;
373              })) {
374       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
375 
376       // Check that the dependence is indeed on the input `consumerIdx` view.
377       auto consumedView =
378           consumer.getBuffer(dependence.indexingOpView.operandIndex);
379       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
380         continue;
381 
382       // Consumer consumes this view, `isStructurallyFusableProducer` also
383       // checks whether it is a strict subview of the producer view.
384       auto producedView =
385           producer.getBuffer(dependence.dependentOpView.operandIndex);
386       LLVM_DEBUG(llvm::dbgs()
387                  << "\n"
388                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
389                  << "producer: " << *producer.getOperation()
390                  << " view: " << producedView << " output index: "
391                  << dependence.dependentOpView.operandIndex -
392                         producer.getNumInputs()
393                  << "\n");
394       (void)producedView;
395 
396       // Simple fusability checks.
397       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
398         continue;
399 
400       return dependence;
401     }
402   }
403   return {};
404 }
405 
406 Optional<FusionInfo>
407 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
408                                    unsigned consumerIdx,
409                                    const LinalgDependenceGraph &graph) {
410   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
411       findFusableProducer(consumer, consumerIdx, graph);
412   if (!fusableDependence)
413     return {};
414 
415   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
416   // If producer is already in the same block as consumer, we are done.
417   if (consumer.getOperation()->getBlock() ==
418       producerOp.getOperation()->getBlock())
419     return {};
420 
421   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
422                          producerOp.getNumInputs();
423   Value consumerView = consumer.getShapedOperand(consumerIdx);
424 
425   // Must be a subview or a slice to guarantee there are loops we can fuse
426   // into.
427   auto subView = consumerView.getDefiningOp<SubViewOp>();
428   auto slice = consumerView.getDefiningOp<SliceOp>();
429   if (!subView && !slice) {
430     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
431     return {};
432   }
433 
434   // Fuse `producer` just before `consumer`.
435   OpBuilder::InsertionGuard g(b);
436   b.setInsertionPoint(consumer.getOperation());
437   ScopedContext scope(b, consumer.getLoc());
438   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
439 
440   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
441   return FusionInfo{producerOp, fusedProducer};
442 }
443 
444 /// Walk back use-def chain through scf::For yields.
445 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
446 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
447                                 unsigned &outputIndex) {
448   if (!tensor.getType().isa<RankedTensorType>())
449     return;
450 
451   while (true) {
452     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
453       producer = linalgOp;
454       outputIndex = tensor.cast<OpResult>().getResultNumber();
455       return;
456     }
457     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
458       tensor = subTensorOp.source();
459       continue;
460     }
461     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
462       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
463         tensor = forOp.getResult(blockArg.getArgNumber());
464         continue;
465       }
466     }
467     return;
468   }
469 }
470 
471 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
472                                                         LinalgOp consumer,
473                                                         unsigned consumerIdx) {
474   Value inputTensor = consumer.getInput(consumerIdx);
475   LinalgOp producerOp;
476   unsigned producerIdx;
477   getProducerOfTensor(inputTensor, producerOp, producerIdx);
478 
479   // Must be a subtensor to guarantee there are loops we can fuse into.
480   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
481   if (!subTensor || !producerOp) {
482     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
483     return {};
484   }
485 
486   // If producer is already in the same block as consumer, we are done.
487   if (consumer.getOperation()->getBlock() ==
488       producerOp.getOperation()->getBlock())
489     return {};
490 
491   // Insert fused `producer` just before `consumer`.
492   OpBuilder::InsertionGuard g(b);
493   b.setInsertionPoint(consumer.getOperation());
494   ScopedContext scope(b, consumer.getLoc());
495   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
496   LinalgOp fusedProducer =
497       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
498 
499   // Replace use.
500   // Canonicalizations are not guaranteed to have happened before constructing
501   // `fusedProducer`. In the tensor case this can result in temporary type
502   // mismatches. Insert a `tensor_cast` op to propagate the transformation
503   // invariant that types are compatible.
504   Value def = fusedProducer.getOperation()->getResult(producerIdx);
505   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
506   Type consumerType = use.get().getType();
507   if (consumerType != def.getType())
508     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
509   use.set(def);
510   return FusionInfo{producerOp, fusedProducer};
511 }
512 
513 /// Prune all dimensions that are of reduction iterator type from `map`.
514 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
515                                            AffineMap map) {
516   SmallVector<unsigned, 2> projectedDims;
517   for (auto attr : llvm::enumerate(iteratorTypes)) {
518     if (!isParallelIterator(attr.value()))
519       projectedDims.push_back(attr.index());
520   }
521   return getProjectedMap(map, projectedDims);
522 }
523 
524 using FusableOpDependencesTy = llvm::MapVector<
525     Operation *,
526     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
527 
528 /// Returns the positions of the loop in `op` that can be tiled based on the
529 /// operations that are to be fused with it. For example, in a
530 ///
531 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
532 ///
533 /// if the producer of %a needs to be fused with this op, only the `i` loop of
534 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
535 /// fused, then no loops can be tiled while fusing. The conditions used are:
536 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
537 ///    common outer parallel loops between the op and its producers being fused.
538 /// 2. Of the parallel loops only some can be fused. Only those loops can be
539 ///    fused such where the fusable loops iteration space only touches one tile
540 ///    of the fused operation. This is because the producer (which is writing
541 ///    the fused subview) has update semantics. To compute this,
542 ///    a. Find the mapping from iterations in the consumer that write to the
543 ///       same location as the iterations in the producer. To do so use
544 ///       - indexing map of the fused view in the consumer : consumerIndexMap
545 ///       - indexing map of the fused view in the producer : producerIndexMap
546 ///       consumerLoopToProducerLoop =
547 ///         inverse(producerIndexMap).compose(consumerIndexMap)
548 ///
549 /// Since an inverse computation is needed, we need to consider the projection
550 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
551 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
552 /// parallel loops and appear in the result of the map
553 ///
554 /// Example 1:
555 ///   linalg.fill(%c, %cst)
556 ///   linalg.matmul ins(%a, %b) outs(%c)
557 ///     Number of parallel loops : 2
558 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
559 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
560 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
561 ///     Fused dimensions : i, j
562 ///
563 /// Example 2:
564 ///   linalg.matmul ins(%a, %b) outs(%c)
565 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
566 ///                   iterator_types = ["parallel", "parallel"]}
567 ///     ins(%c) ...
568 ///
569 ///     Number of parallel loops = 2:
570 ///     producerIndexMap (projected to parallel loops) =
571 ///       affine_map<(i, j) -> (i, j)>
572 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
573 ///     Fused dimensions : i, j
574 ///
575 /// Example 3:
576 ///   linalg.copy(%s, %b)
577 ///   linalg.matmul ins(%a, %b) outs(%c)
578 ///
579 ///   Number of parallel loops = 2
580 ///   produceIndexMap : affine_map<(i, j) -> (i, j)>
581 ///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
582 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
583 ///   Fused dimensions : j
584 static std::set<unsigned>
585 collectTileAndFuseLoops(LinalgOp op,
586                         const FusableOpDependencesTy &fusableDependences) {
587   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
588     return linalgOp.iterator_types()
589         .getValue()
590         .take_while([](Attribute attr) -> bool {
591           return attr.cast<StringAttr>().getValue() ==
592                  getParallelIteratorTypeName();
593         })
594         .size();
595   };
596 
597   LLVM_DEBUG({
598     llvm::dbgs() << "Op : ";
599     op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
600     llvm::dbgs() << "\n";
601   });
602 
603   size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
604   for (auto dependence : fusableDependences) {
605     linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
606     numOuterParallelLoops =
607         std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer));
608   }
609 
610   std::set<unsigned> fusableLoops;
611   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
612   fusableLoops.insert(range.begin(), range.end());
613   for (auto dependence : fusableDependences) {
614     LLVM_DEBUG({
615       llvm::dbgs() << "\t fusable :";
616       for (unsigned i : fusableLoops)
617         llvm::dbgs() << " " << i;
618       llvm::dbgs() << "\n";
619     });
620     linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first);
621 
622     assert(!dependence.second.empty() &&
623            "unexpected producer but not dependences");
624     AffineMap producerIndexingMap = producer.getIndexingMap(
625         dependence.second.front().dependentOpView.operandIndex);
626     AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
627         producer.iterator_types().getValue(), producerIndexingMap);
628     if (!prunedProducerIndexingMap.isPermutation())
629       return {};
630 
631     AffineMap consumerIndexingMap = op.getIndexingMap(
632         dependence.second.front().indexingOpView.operandIndex);
633     if (consumerIndexingMap.getNumResults() !=
634         prunedProducerIndexingMap.getNumResults())
635       return {};
636 
637     LLVM_DEBUG({
638       llvm::dbgs() << "\t producerMap : ";
639       producerIndexingMap.print(llvm::dbgs());
640       llvm::dbgs() << "  pruned : ";
641       prunedProducerIndexingMap.print(llvm::dbgs());
642       llvm::dbgs() << "\n";
643       llvm::dbgs() << "\t consumerMap : ";
644       consumerIndexingMap.print(llvm::dbgs());
645       llvm::dbgs() << "\n";
646     });
647 
648     AffineMap invProducerIndexMap =
649         inversePermutation(prunedProducerIndexingMap);
650     if (!invProducerIndexMap)
651       return {};
652 
653     AffineMap consumerLoopToProducerLoop =
654         invProducerIndexMap.compose(consumerIndexingMap);
655 
656     LLVM_DEBUG({
657       llvm::dbgs() << "\t consumerLoopToProducerLoop : ";
658       consumerLoopToProducerLoop.print(llvm::dbgs());
659     });
660 
661     std::set<unsigned> candidates;
662     for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) {
663       AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>();
664       if (!dimExpr)
665         continue;
666       unsigned position = dimExpr.getPosition();
667       if (fusableLoops.count(position))
668         candidates.insert(position);
669     }
670     LLVM_DEBUG({
671       llvm::dbgs() << "\t candidates :";
672       for (unsigned i : candidates)
673         llvm::dbgs() << " " << i;
674       llvm::dbgs() << "\n";
675     });
676     if (candidates.empty())
677       return {};
678     std::swap(candidates, fusableLoops);
679   }
680 
681   return fusableLoops;
682 }
683 
684 /// Find all dependences that are to be fusable.
685 static FusableOpDependencesTy
686 findAllFusableDependences(LinalgOp op,
687                           const LinalgDependenceGraph &dependenceGraph,
688                           const LinalgFusionOptions &fusionOptions) {
689   FusableOpDependencesTy fusableDependences;
690   // TODO: Currently fusion would not be legal if the fusable dependence is to
691   // the same producer but different indexing map in the consumer. Fix this, but
692   // in the meanwhile disallow such a fusion.
693   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
694   for (auto operandIndex : fusionOptions.indicesToFuse) {
695     auto fusableDependence =
696         findFusableProducer(op, operandIndex, dependenceGraph);
697     if (!fusableDependence)
698       return FusableOpDependencesTy{};
699     LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
700     // Do not fuse dependences that are to operations not in the same basic
701     // block. This avoid moving fused operations across loops that might
702     // themselves carry dependency making the fusion illegal.
703     if (producerOp.getOperation()->getBlock() !=
704         op.getOperation()->getBlock()) {
705       op.emitRemark("unhandled fusion of ops in different basic blocks");
706       return FusableOpDependencesTy{};
707     }
708     // Make sure that the indexing map of the view used for fusion in the
709     // producer is a projected permutation.
710     unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
711     AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
712     if (!producerMap.isProjectedPermutation()) {
713       op.emitRemark("unhandled non permutation indexing map for fused view in "
714                     "producer for operand at index ")
715           << operandIndex;
716       return FusableOpDependencesTy{};
717     }
718 
719     unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
720     AffineMap consumerMap = op.getIndexingMap(consumerIdx);
721     if (!consumerMap.isProjectedPermutation()) {
722       op.emitRemark(
723           "unhandled case where indexing map for fused view in the consumer is "
724           "not a projected permutation while fusing at index ")
725           << operandIndex;
726       return FusableOpDependencesTy{};
727     }
728 
729     // Check if the producer is already a fusion candidate. Cannot fuse this
730     // dependence if it has a different indexing map when used in the consumer.
731     if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
732         fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
733       op.emitRemark("unhandled fusion to the same producer but with different "
734                     "indexing maps");
735       return FusableOpDependencesTy{};
736     }
737     fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
738 
739     fusableDependences[producerOp.getOperation()].push_back(*fusableDependence);
740   }
741   return fusableDependences;
742 }
743 
744 static bool isZero(Value v) {
745   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
746     return cst.getValue() == 0;
747   return false;
748 }
749 
750 template <typename LoopType>
751 static Optional<TiledAndFusedLinalgOps>
752 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
753                          const LinalgDependenceGraph &dependenceGraph,
754                          const LinalgTilingOptions &tilingOptions,
755                          const LinalgFusionOptions &fusionOptions) {
756   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
757   // Some of the tiling options might not be supportable with tile and fuse.
758   // TODO: Support interchange with tile + fuse.
759   if (!tilingOptions.interchangeVector.empty()) {
760     op.emitError("unable to handle tile and fuse with interchange");
761     return llvm::None;
762   }
763 
764   OpBuilder::InsertionGuard g(rewriter);
765   rewriter.setInsertionPoint(op);
766   ScopedContext scope(rewriter, op.getLoc());
767 
768   // Find all the producers.
769   FusableOpDependencesTy fusableDependences =
770       findAllFusableDependences(op, dependenceGraph, fusionOptions);
771   if (fusableDependences.empty())
772     return llvm::None;
773 
774   // Enforce the convention that "tiling by zero" skips tiling a particular
775   // dimension. This convention is significantly simpler to handle instead of
776   // adjusting affine maps to account for missing dimensions.
777   auto nLoops = op.getNumLoops();
778   SmallVector<Value, 4> tileSizeVector =
779       tilingOptions.tileSizeComputationFunction(rewriter, op);
780   if (tileSizeVector.size() < nLoops) {
781     auto zero = std_constant_index(0);
782     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
783   }
784 
785   TiledAndFusedLinalgOps ret;
786 
787   // Find the loops that can be tiled and fused.
788   std::set<unsigned> tileFuseLoops =
789       collectTileAndFuseLoops(op, fusableDependences);
790 
791   // If there are no fusable dependences or there are no tile+fusable loops,
792   // just return.
793   if (tileFuseLoops.empty()) {
794     return llvm::None;
795   }
796 
797   // Get the tile sizes for the first and second tiling steps. For the first
798   // step the tile size are set to zero for the loops that arent
799   // fused. Similarly for the second step, the tile sizes are set to zero for
800   // the loops that are fused. For example, if for the following input
801   //
802   // ```
803   //   linalg.add ins(%a, %b) outs(%c)
804   //   linalg.matmul ins(%d, %c) outs(%e)
805   // ```
806   //
807   // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
808   // respectively, and since only `j` can be tiled and fused. The tile sizes
809   // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
810   // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
811   // the tiled matmul generated by the first tiling step.
812   SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
813   for (auto tileSize : enumerate(tileSizeVector)) {
814     auto zero = std_constant_index(0);
815     if (tileFuseLoops.count(tileSize.index())) {
816       tileAndFuseSizes.push_back(tileSize.value());
817       tileSizes.push_back(zero);
818     } else {
819       tileSizes.push_back(tileSize.value());
820       tileAndFuseSizes.push_back(zero);
821     }
822   }
823 
824   // Tile for the loops that can be fused.
825   LinalgTilingOptions firstTilingOptions = tilingOptions;
826   firstTilingOptions.setTileSizes(tileAndFuseSizes);
827   Optional<TiledLinalgOp> firstTiledOp =
828       tileLinalgOp(rewriter, op, firstTilingOptions);
829   if (!firstTiledOp)
830     return llvm::None;
831   ret.op = firstTiledOp->op;
832   ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
833 
834   rewriter.setInsertionPoint(ret.op);
835   // Fuse the operands.
836   for (auto dependence : fusableDependences) {
837     LinalgOp producerOp = cast<LinalgOp>(dependence.first);
838     unsigned producerIdx =
839         dependence.second.front().dependentOpView.operandIndex;
840     unsigned consumerIdx =
841         dependence.second.front().indexingOpView.operandIndex;
842     LinalgOp fusedOp = fuse(rewriter, producerOp,
843                             producerOp.getOutputIndex(producerIdx).getValue(),
844                             ret.op, consumerIdx);
845     ret.fusedProducers.push_back(fusedOp);
846     ret.originalProducers.push_back(producerOp);
847   }
848 
849   if (!llvm::all_of(tileSizes, isZero)) {
850     // Tile the remaining loops of the root operation.
851     LinalgTilingOptions secondTilingOptions = tilingOptions;
852     // The distribution is done only for the tile+fused loops.
853     secondTilingOptions.distribution = llvm::None;
854     secondTilingOptions.setTileSizes(tileSizes);
855     Optional<TiledLinalgOp> secondTiledOp =
856         tileLinalgOp(rewriter, ret.op, secondTilingOptions);
857     if (!secondTiledOp)
858       return llvm::None;
859     ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
860                             secondTiledOp->loops.end());
861     rewriter.eraseOp(ret.op);
862     ret.op = secondTiledOp->op;
863   }
864 
865   return ret;
866 }
867 
868 Optional<TiledAndFusedLinalgOps>
869 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
870                                    const LinalgDependenceGraph &dependenceGraph,
871                                    const LinalgTilingOptions &tilingOptions,
872                                    const LinalgFusionOptions &fusionOptions) {
873   switch (tilingOptions.loopType) {
874   case LinalgTilingLoopType::Loops:
875     return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
876                                                 tilingOptions, fusionOptions);
877   case LinalgTilingLoopType::ParallelLoops:
878     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
879         rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
880   default:;
881   }
882   return llvm::None;
883 }
884