xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 9b17bf2e54c71b36bf28fbab05698fb73ea8dda9)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "linalg-fusion"
32 
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 /// Implements a simple high-level fusion pass on linalg structured operations.
41 ///
42 /// In each block, linalg ops are processed in reverse textual order.
43 /// Given a linalg op `O`, fusion occurs by:
44 ///   1. inspecting the linalg ops that write into the views read by `O`. There
45 ///      are 2 cases:
46 ///      a) buffer case: use the SSA value of the views and a simple alias
47 ///         analysis on subview ops to determine producer-consumer dependences;
48 ///      b) tensor case: use SSA use-def chains on subtensor ops;
49 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
50 ///   3. inspect the fused ops and determine whether they have other remaining
51 ///      LinalgOp uses. If not, then erase the original producing linalg op.
52 ///
53 /// More advanced use cases, analyses as well as profitability heuristics are
54 /// left for future work.
55 
56 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
57 // by `permutationMap`.
58 static void inferShapeComponents(AffineMap permutationMap,
59                                  ArrayRef<Range> loopRanges,
60                                  SmallVectorImpl<Value> &offsets,
61                                  SmallVectorImpl<Value> &sizes,
62                                  SmallVectorImpl<Value> &strides) {
63   assert(permutationMap.isProjectedPermutation() &&
64          "expected some subset of a permutation map");
65   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
66   unsigned idx = 0;
67   for (AffineExpr e : permutationMap.getResults()) {
68     // loopToOperandRangesMaps are permutations-only, just swap indices.
69     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
70     shapeRanges[idx++] = loopRanges[loopPos];
71   }
72   // Construct a new subshape for the tile.
73   unsigned rank = shapeRanges.size();
74   offsets.reserve(rank);
75   sizes.reserve(rank);
76   strides.reserve(rank);
77   for (auto r : shapeRanges) {
78     offsets.push_back(r.offset);
79     sizes.push_back(r.size);
80     strides.push_back(r.stride);
81   }
82 }
83 
84 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
85 // a subset of the original loop ranges of `op`.
86 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
87 // to the `loopRanges` in order to obtain view ranges.
88 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
89                                     ArrayRef<Range> loopRanges) {
90   SmallVector<Value, 8> clonedShapes;
91   clonedShapes.reserve(op.getNumShapedOperands());
92 
93   // Iterate over the shape operands in order.
94   // Extract the subranges from the linearized ranges.
95   for (auto en : llvm::enumerate(op.getShapedOperands())) {
96     unsigned shapedOperandIdx = en.index();
97     AffineMap map = op.getIndexingMap(shapedOperandIdx);
98     LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
99                       << " with indexingMap: " << map << "\n");
100     SmallVector<Value, 4> offsets, sizes, strides;
101     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
102     Value shape = en.value();
103     Value sub = shape.getType().isa<MemRefType>()
104                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
105                           .getResult()
106                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
107                           .getResult();
108     clonedShapes.push_back(sub);
109   }
110   // Append the other operands.
111   auto operands = op.getAssumedNonShapedOperands();
112   clonedShapes.append(operands.begin(), operands.end());
113 
114   // Iterate over the results in order.
115   // Extract the subtensor type from the linearized range.
116   // Since we do not enforce any canonicalizations on the fly, this is always
117   // fully dynamic at construction time.
118   SmallVector<Type, 4> resultTypes;
119   resultTypes.reserve(op.getOperation()->getNumResults());
120   for (RankedTensorType t : op.getOutputTensorTypes()) {
121     unsigned rank = t.getRank();
122     SmallVector<int64_t, 4> staticOffsetsVector(
123         rank, ShapedType::kDynamicStrideOrOffset);
124     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
125     SmallVector<int64_t, 4> staticStridesVector(
126         rank, ShapedType::kDynamicStrideOrOffset);
127     resultTypes.push_back(SubTensorOp::inferResultType(
128         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
129         staticStridesVector));
130   }
131 
132   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
133   // When the producer is an IndexedGenericOp, we have to transform its block
134   // IV arguments according to the tiling of the consumer, i.e. offset them by
135   // the values computed in `loopRanges`.
136   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
137     auto &block = indexedGenericOp.region().front();
138     OpBuilder::InsertionGuard g(b);
139     b.setInsertionPointToStart(&block);
140     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
141       Value oldIndex = block.getArgument(i);
142       // TODO: replace by an affine_apply.
143       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
144                                          loopRanges[i].offset);
145       oldIndex.replaceAllUsesExcept(newIndex,
146                                     SmallPtrSet<Operation *, 1>{newIndex});
147     }
148   }
149 
150   return clonedOp;
151 }
152 
153 struct ShapeDimension {
154   Value shape;
155   unsigned dimension;
156 };
157 
158 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
159 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
160 // guarantees at least one such dimension is found. If multiple candidates exist
161 // they must agree by construction (i.e. have the same size) and we just return
162 // the first one.
163 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
164                                                 unsigned loopDepth) {
165   auto maps = op.indexing_maps();
166   // Iterate over the inputs and outputs in order.
167   // Extract the subranges from the linearized ranges.
168   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
169   for (auto en : llvm::enumerate(ios)) {
170     unsigned idx = en.index();
171     auto map = maps[idx].cast<AffineMapAttr>().getValue();
172     LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
173     LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
174     Value shape = en.value();
175     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
176     for (auto en2 : llvm::enumerate(map.getResults())) {
177       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
178         LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
179                           << loopDepth << "\n");
180         LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
181                           << "\n");
182         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
183       }
184     }
185   }
186   llvm_unreachable("Expect to be able to extract a shape defining loop range");
187 }
188 
189 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
190 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
191 /// is needed just before the `consumer.
192 ///
193 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
194 /// 2 cases:
195 ///   1. Buffer case: `producerIdx` is the index of the buffer in
196 ///      `producer.getOutputBuffers()`.
197 ///   2. Tensor case: `producerIdx` is the index of the tensor in
198 ///      `producer.getResults()`.
199 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
200                      LinalgOp consumer, unsigned consumerIdx) {
201   Operation *shapeProducingOp =
202       consumer.getShapedOperand(consumerIdx).getDefiningOp();
203   assert((isa<SubViewOp>(shapeProducingOp) ||
204           isa<SubTensorOp>(shapeProducingOp)) &&
205          "SubviewOp or SubTensorOp expected");
206 
207   // loopToOperandRangesMaps are permutations-only by construction:
208   //   we can always identify a data dimension with a (at least one) loop
209   //   dimension.
210   // TODO: extend this with range inference.
211   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
212   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
213                     << ", producer map: " << producerMap << "\n");
214 
215   unsigned nPar = producer.getNumParallelLoops();
216   unsigned nRed = producer.getNumReductionLoops();
217   unsigned nWin = producer.getNumWindowLoops();
218   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
219 
220   // Iterate over dimensions identified by the producer map for `producerIdx`.
221   // This defines a subset of the loop ranges that we need to complete later.
222   auto loc = consumer.getLoc();
223   for (auto en : llvm::enumerate(producerMap.getResults())) {
224     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
225     loopRanges[posInProducerLoop] =
226         isa<SubViewOp>(shapeProducingOp)
227             ? cast<SubViewOp>(shapeProducingOp)
228                   .getOrCreateRanges(b, loc)[en.index()]
229             : cast<SubTensorOp>(shapeProducingOp)
230                   .getOrCreateRanges(b, loc)[en.index()];
231   }
232 
233   // Iterate over all dimensions. For the dimensions not identified by the
234   // producer map for `producerIdx`, we need to explicitly compute the shape
235   // that defines the loop ranges using the `producer`.
236   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
237     if (loopRanges[i].offset)
238       LLVM_DEBUG(llvm::dbgs()
239                  << "existing LoopRange: " << loopRanges[i] << "\n");
240     else {
241       auto shapeDim = getShapeDefiningLoopRange(producer, i);
242       loopRanges[i] = Range{std_constant_index(0),
243                             std_dim(shapeDim.shape, shapeDim.dimension),
244                             std_constant_index(1)};
245       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
246     }
247   }
248 
249   return cloneWithLoopRanges(b, loc, producer, loopRanges);
250 }
251 
252 // Encode structural fusion safety preconditions.
253 // Some of these will be lifted in the future with better analysis.
254 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
255                                           LinalgOp consumer) {
256   assert(producer.hasBufferSemantics() &&
257          "expected linalg op with buffer semantics");
258   assert(consumer.hasBufferSemantics() &&
259          "expected linalg op with buffer semantics");
260   if (producer.getNumOutputs() != 1) {
261     LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
262     return false;
263   }
264   // Only fuse when the producer block dominates.
265   DominanceInfo dom(producer.getOperation());
266   if (!dom.dominates(producer.getOperation()->getBlock(),
267                      consumer.getOperation()->getBlock())) {
268     LLVM_DEBUG(
269         dbgs()
270         << "\nNot structurally fusable (producer block does not dominate)");
271     return false;
272   }
273   return true;
274 }
275 
276 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
277                                              LinalgOp consumer,
278                                              Value consumedView,
279                                              LinalgOp producer) {
280   assert(producer.hasBufferSemantics() &&
281          "expected linalg op with buffer semantics");
282   assert(consumer.hasBufferSemantics() &&
283          "expected linalg op with buffer semantics");
284   // Make some simple structural checks that alleviate the need for more
285   // complex analyses.
286   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
287     LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
288                       << *producer.getOperation());
289     return false;
290   }
291   // Check for any interleaved write to consumedView.
292   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
293     LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
294                       << *producer.getOperation());
295     return false;
296   }
297   return true;
298 }
299 
300 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
301                                  LinalgOp consumer, Value consumedView,
302                                  LinalgOp producer) {
303   assert(producer.hasBufferSemantics() &&
304          "expected linalg op with buffer semantics");
305   assert(consumer.hasBufferSemantics() &&
306          "expected linalg op with buffer semantics");
307   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
308     return false;
309   // Check for any fusion-preventing dependence to any shape read/written that
310   // would violate dependences.
311   if (!graph.findCoveringDependences(producer, consumer).empty()) {
312     LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
313                       << *producer.getOperation());
314     return false;
315   }
316   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
317     // TODO: add a level of indirection to linalg.generic.
318     if (convOp.padding())
319       return false;
320   }
321   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
322     // TODO: add a level of indirection to linalg.generic.
323     if (convOp.padding())
324       return false;
325   }
326   return true;
327 }
328 
329 static bool isSameSubView(Value a, Value b) {
330   if (a == b)
331     return true;
332   auto sva = a.getDefiningOp<SubViewOp>();
333   auto svb = b.getDefiningOp<SubViewOp>();
334   if (!sva || !svb)
335     return false;
336   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
337     return false;
338   if (sva.getType() != svb.getType())
339     return false;
340   if (sva.getNumOperands() != svb.getNumOperands())
341     return false;
342   if (sva.static_offsets() != svb.static_offsets())
343     return false;
344   if (sva.static_sizes() != svb.static_sizes())
345     return false;
346   if (sva.static_strides() != svb.static_strides())
347     return false;
348   /// Skip the "source" operand.
349   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
350     if (sva.getOperand(idx) != svb.getOperand(idx))
351       return false;
352   return true;
353 }
354 
355 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
356 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
357                     const LinalgDependenceGraph &dependenceGraph) {
358   // Only consider RAW and WAW atm.
359   for (auto depType : {
360            LinalgDependenceGraph::DependenceType::RAW,
361            LinalgDependenceGraph::DependenceType::WAW,
362        }) {
363     for (auto dependence :
364          dependenceGraph.getDependencesInto(consumer, depType)) {
365       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
366 
367       // Check that the dependence is indeed on the input `consumerIdx` view.
368       auto consumedView = dependence.indexingView;
369       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
370         continue;
371 
372       // Consumer consumes this view, `isStructurallyFusableProducer` also
373       // checks whether it is a strict subview of the producer view.
374       auto producedView = dependence.dependentOpView.view;
375       auto producerIdx =
376           producer.getIndexOfOutputBuffer(producedView).getValue();
377       // `consumerIdx` and `producerIdx` exist by construction.
378       LLVM_DEBUG(dbgs() << "\n"
379                         << LinalgDependenceGraph::getDependenceTypeStr(depType)
380                         << "producer: " << *producer.getOperation() << " view: "
381                         << producedView << " output index: " << producerIdx);
382       (void)producerIdx;
383 
384       // Simple fusability checks.
385       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
386         continue;
387 
388       return dependence;
389     }
390   }
391   return {};
392 }
393 
394 Optional<FusionInfo>
395 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
396                                    unsigned consumerIdx,
397                                    const LinalgDependenceGraph &graph) {
398   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
399       findFusableProducer(consumer, consumerIdx, graph);
400   if (!fusableDependence)
401     return {};
402 
403   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
404   // If producer is already in the same block as consumer, we are done.
405   if (consumer.getOperation()->getBlock() ==
406       producerOp.getOperation()->getBlock())
407     return {};
408 
409   Value producerView = fusableDependence->dependentOpView.view;
410   Value consumerView = fusableDependence->indexingView;
411 
412   // Must be a subview or a slice to guarantee there are loops we can fuse
413   // into.
414   auto subView = consumerView.getDefiningOp<SubViewOp>();
415   auto slice = consumerView.getDefiningOp<SliceOp>();
416   if (!subView && !slice) {
417     LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
418     return {};
419   }
420 
421   // Fuse `producer` just before `consumer`.
422   OpBuilder::InsertionGuard g(b);
423   b.setInsertionPoint(consumer.getOperation());
424   ScopedContext scope(b, consumer.getLoc());
425   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
426   Optional<unsigned> producerIdxOpt =
427       producerOp.getIndexOfOutputBuffer(producerView);
428   assert(producerIdxOpt.hasValue() && "incorrect operand index");
429   unsigned producerIdx = producerIdxOpt.getValue();
430 
431   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
432   return FusionInfo{producerOp, fusedProducer};
433 }
434 
435 /// Walk back use-def chain through scf::For yields.
436 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
437 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
438                                 unsigned &outputIndex) {
439   if (!tensor.getType().isa<RankedTensorType>())
440     return;
441 
442   while (true) {
443     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
444       producer = linalgOp;
445       outputIndex = tensor.cast<OpResult>().getResultNumber();
446       return;
447     }
448     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
449       tensor = subTensorOp.source();
450       continue;
451     }
452     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
453       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
454         tensor = forOp.getResult(blockArg.getArgNumber());
455         continue;
456       }
457     }
458     return;
459   }
460 }
461 
462 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
463                                                         LinalgOp consumer,
464                                                         unsigned consumerIdx) {
465   Value inputTensor = consumer.getInput(consumerIdx);
466   LinalgOp producerOp;
467   unsigned producerIdx;
468   getProducerOfTensor(inputTensor, producerOp, producerIdx);
469 
470   // Must be a subtensor to guarantee there are loops we can fuse into.
471   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
472   if (!subTensor || !producerOp) {
473     LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
474     return {};
475   }
476 
477   // If producer is already in the same block as consumer, we are done.
478   if (consumer.getOperation()->getBlock() ==
479       producerOp.getOperation()->getBlock())
480     return {};
481 
482   // Insert fused `producer` just before `consumer`.
483   OpBuilder::InsertionGuard g(b);
484   b.setInsertionPoint(consumer.getOperation());
485   ScopedContext scope(b, consumer.getLoc());
486   LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
487   LinalgOp fusedProducer =
488       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
489 
490   // Replace use.
491   // Canonicalizations are not guaranteed to have happened before constructing
492   // `fusedProducer`. In the tensor case this can result in temporary type
493   // mismatches. Insert a `tensor_cast` op to propagate the transformation
494   // invariant that types are compatible.
495   Value def = fusedProducer.getOperation()->getResult(producerIdx);
496   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
497   Type consumerType = use.get().getType();
498   if (consumerType != def.getType())
499     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
500   use.set(def);
501   return FusionInfo{producerOp, fusedProducer};
502 }
503 
504 /// Returns the positions of the loop in `op` that can be tiled based on the
505 /// operations that are to be fused with it. For example, in a
506 ///
507 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
508 ///
509 /// if the producer of %a needs to be fused with this op, only the `i` loop of
510 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
511 /// fused, then no loops can be tiled while fusing.
512 static DenseSet<unsigned> collectTileAndFuseLoops(
513     LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
514                      fusableDependences) {
515   // 1. Only parallel loops can be used for tile + fuse. Find the number of
516   // common outer parallel loops between the op and its producers being fused.
517   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
518     return linalgOp.iterator_types()
519         .getValue()
520         .take_while([](Attribute attr) -> bool {
521           return attr.cast<StringAttr>().getValue() ==
522                  getParallelIteratorTypeName();
523         })
524         .size();
525   };
526 
527   size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
528   for (auto dependence : fusableDependences) {
529     numOuterParallelLoops =
530         std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
531                                             dependence.dependentOpView.op)));
532   }
533 
534   // Need to compute what tiled loops can be "fused". Given the precondition
535   // that all indexing map for the producer view is a projected permutation, we
536   // can assert that the producer iterates over the dimensions of the "fused
537   // view" only once. To be used a fused loop the producer should use this loop
538   // to access the fused view. For example, consider
539   //
540   // ```
541   //   linalg.add ins(%a, %b) outs(%c)
542   //   linalg.matmul ins(%d, %c) outs(%e)
543   // ```
544   //
545   // if `linalg.add` has the semantics of `c = a + b`, then the following
546   // tile+fuse code is correct.
547   //
548   // ```
549   // for j ... += TSj
550   //   %sa = subview %a[0, %j][...]
551   //   %sb = subview %b[0, %j][...]
552   //   %sc = subview %c[0, %j][...]
553   //   %sd = subview %d[0, 0][...]
554   //   %se = subview %e[0, %j][...]
555   //   linalg.add ins(%sa, %sb) outs(%sc)
556   //   linalg.matmul ins(%sd, %sc) outs(%se)
557   // ```
558   //
559   // On the other hand tiling along i would be incorrect
560   //
561   // ```
562   // for %i .. += TSi
563   //   %sa = subview %a[%i, 0][...]
564   //   %sb = subview %b[%i, 0][...]
565   //   %sc = subview %c[%i, 0][...]
566   //   %sc2 = subview %c[0, 0][...]
567   //   %sd = subview %d[%i, 0][...]
568   //   %se = subview %e[%i, 0][...]
569   //   linalg.add ins(%sa, %sb) outs(%sc)
570   //   linalg.matmul ins(%sd, %sc2) outs(%se)
571   // ```
572   //
573   // The write to the subview `%sc` in `linalg.add` is performed after the read
574   // from it using `%sc2` violating the RAW dependence of the original code. To
575   // find such loops indexing map of the fused view in the consumer op is
576   // used. For the above example, this indexing map is
577   //
578   //   affine_map<(d0, d1, d2) -> (d2, d1)>
579   //
580   // Since d0 is not in the result expressions of this map, it is not treated as
581   // tile + fuse loop, (but d1 is).
582   //
583   // TODO: The above is probably restrictive and there might be a generalization
584   // of these that might allow for more fusion opportunities. Explore based on
585   // needs.
586   SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
587   for (auto dependence : fusableDependences) {
588     unsigned consumerIdx =
589         op.getIndexOfShapedOperand(dependence.indexingView).getValue();
590     AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
591     // Previously asserted that the consumerAccess map is a projected
592     // permutation, so all results are known to be AffineDimExprs. To remove
593     // this restriction walk the expression to find which dimensions of the
594     // consumer loop appear in the `consumerAccess`.
595     DenseSet<unsigned> positions;
596     for (auto expr : consumerAccess.getResults())
597       positions.insert(expr.cast<AffineDimExpr>().getPosition());
598     commonTilableLoops.emplace_back(std::move(positions));
599   }
600 
601   // 2. Of the outer parallel loops, only those loops can be tiled + fused as
602   // computed above for all the fused dependences can be used to tile and fuse.
603   DenseSet<unsigned> tilableParallelLoops;
604   for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
605     if (llvm::all_of(commonTilableLoops,
606                      [&](const DenseSet<unsigned> &tilableLoops) {
607                        return tilableLoops.count(index);
608                      }))
609       tilableParallelLoops.insert(index);
610   }
611   return tilableParallelLoops;
612 }
613 
614 /// Find all dependences that are to be fusable.
615 static Optional<
616     SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
617 findAllFusableDependences(LinalgOp op,
618                           const LinalgDependenceGraph &dependenceGraph,
619                           const LinalgFusionOptions &fusionOptions) {
620   SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
621       fusableDependences;
622   for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
623     if (fusionOptions.indicesToFuse &&
624         !fusionOptions.indicesToFuse->count(operand.index()))
625       continue;
626     Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
627         fusableDependence =
628             findFusableProducer(op, operand.index(), dependenceGraph);
629     if (!fusableDependence)
630       continue;
631     // Make sure that the indexing map of the view used for fusion in the
632     // producer is a projected permutation.
633     LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
634     Value producerView = fusableDependence->dependentOpView.view;
635     unsigned producerIdx =
636         producerOp.getIndexOfOutputBuffer(producerView).getValue();
637     AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
638     if (!producerMap.isProjectedPermutation()) {
639       op.emitError("unhandled non permutation indexing map for fused view in "
640                    "producer for operand at index ")
641           << operand.index();
642       return llvm::None;
643     }
644     Value consumerView = fusableDependence->indexingView;
645     unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
646     if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
647       op.emitError(
648           "unhandled case where indexing map for fused view in the consumer is "
649           "not a projected permuration while fusing at index ")
650           << operand.index();
651       return llvm::None;
652     }
653     fusableDependences.push_back(*fusableDependence);
654     if (!fusionOptions.indicesToFuse)
655       break;
656   }
657   return fusableDependences;
658 }
659 
660 static bool isZero(Value v) {
661   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
662     return cst.getValue() == 0;
663   return false;
664 }
665 
666 template <typename LoopType>
667 static Optional<TiledAndFusedLinalgOps>
668 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
669                          const LinalgDependenceGraph &dependenceGraph,
670                          const LinalgTilingOptions &tilingOptions,
671                          const LinalgFusionOptions &fusionOptions) {
672   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
673   // Some of the tiling options might not be supportable with tile and fuse.
674   // TODO: Support interchange with tile + fuse.
675   if (!tilingOptions.interchangeVector.empty()) {
676     op.emitError("unable to handle tile and fuse with interchange");
677     return llvm::None;
678   }
679 
680   OpBuilder::InsertionGuard g(rewriter);
681   rewriter.setInsertionPoint(op);
682   ScopedContext scope(rewriter, op.getLoc());
683 
684   // Find all the producers.
685   Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
686       fusableDependencesOpt =
687           findAllFusableDependences(op, dependenceGraph, fusionOptions);
688   if (!fusableDependencesOpt)
689     return llvm::None;
690   ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
691       *fusableDependencesOpt);
692 
693   // Enforce the convention that "tiling by zero" skips tiling a particular
694   // dimension. This convention is significantly simpler to handle instead of
695   // adjusting affine maps to account for missing dimensions.
696   auto nLoops = op.getNumLoops();
697   SmallVector<Value, 4> tileSizeVector =
698       tilingOptions.tileSizeComputationFunction(rewriter, op);
699   if (tileSizeVector.size() < nLoops) {
700     auto zero = std_constant_index(0);
701     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
702   }
703 
704   TiledAndFusedLinalgOps ret;
705 
706   // Find the loops that can be tiled and fused.
707   DenseSet<unsigned> tileFuseLoops =
708       collectTileAndFuseLoops(op, fusableDependences);
709 
710   // If there are no fusable dependences or there are no tile+fusable loops,
711   // just return.
712   if (fusableDependences.empty() || tileFuseLoops.empty()) {
713     return llvm::None;
714   }
715 
716   // Get the tile sizes for the first and second tiling steps. For the first
717   // step the tile size are set to zero for the loops that arent
718   // fused. Similarly for the second step, the tile sizes are set to zero for
719   // the loops that are fused. For example, if for the following input
720   //
721   // ```
722   //   linalg.add ins(%a, %b) outs(%c)
723   //   linalg.matmul ins(%d, %c) outs(%e)
724   // ```
725   //
726   // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
727   // respectively, and since only `j` can be tiled and fused. The tile sizes
728   // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
729   // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
730   // the tiled matmul generated by the first tiling step.
731   SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
732   for (auto tileSize : enumerate(tileSizeVector)) {
733     auto zero = std_constant_index(0);
734     if (tileFuseLoops.count(tileSize.index())) {
735       tileAndFuseSizes.push_back(tileSize.value());
736       tileSizes.push_back(zero);
737     } else {
738       tileSizes.push_back(tileSize.value());
739       tileAndFuseSizes.push_back(zero);
740     }
741   }
742 
743   // Tile for the loops that can be fused.
744   LinalgTilingOptions firstTilingOptions = tilingOptions;
745   firstTilingOptions.setTileSizes(tileAndFuseSizes);
746   Optional<TiledLinalgOp> firstTiledOp =
747       tileLinalgOp(rewriter, op, firstTilingOptions);
748   if (!firstTiledOp)
749     return llvm::None;
750   ret.op = firstTiledOp->op;
751   ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
752 
753   rewriter.setInsertionPoint(ret.op);
754   // Fuse the operands.
755   for (auto producer : enumerate(fusableDependences)) {
756     LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
757     unsigned producerIdx =
758         producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
759             .getValue();
760     unsigned consumerIdx =
761         op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
762     LinalgOp fusedOp =
763         fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
764     ret.fusedProducers.push_back(fusedOp);
765     ret.originalProducers.push_back(producerOp);
766   }
767 
768   if (!llvm::all_of(tileSizes, isZero)) {
769     // Tile the remaining loops of the root operation.
770     LinalgTilingOptions secondTilingOptions = tilingOptions;
771     // The distribution is done only for the tile+fused loops.
772     secondTilingOptions.distribution = llvm::None;
773     secondTilingOptions.setTileSizes(tileSizes);
774     Optional<TiledLinalgOp> secondTiledOp =
775         tileLinalgOp(rewriter, ret.op, secondTilingOptions);
776     if (!secondTiledOp)
777       return llvm::None;
778     ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
779                             secondTiledOp->loops.end());
780     rewriter.eraseOp(ret.op);
781     ret.op = secondTiledOp->op;
782   }
783 
784   return ret;
785 }
786 
787 Optional<TiledAndFusedLinalgOps>
788 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
789                                    const LinalgDependenceGraph &dependenceGraph,
790                                    const LinalgTilingOptions &tilingOptions,
791                                    const LinalgFusionOptions &fusionOptions) {
792   switch (tilingOptions.loopType) {
793   case LinalgTilingLoopType::Loops:
794     return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
795                                                 tilingOptions, fusionOptions);
796   case LinalgTilingLoopType::ParallelLoops:
797     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
798         rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
799   default:;
800   }
801   return llvm::None;
802 }
803