xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (revision 11ea2e2448a5f071a04463f94b7bccfe0a32d264)
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fuse the linalg ops that produce the subview/subtensor.
52 ///   3. inspect the fused ops and determine whether they have other remaining
53 ///      LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57 
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op.getOperation()->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension
166 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
167                           bool fromSubViewOpOnly = false) {
168   auto maps = op.indexing_maps();
169   // Iterate over the inputs and outputs in order.
170   // Extract the subranges from the linearized ranges.
171   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
172   for (auto en : llvm::enumerate(ios)) {
173     // The method `getRangeFromOperandShape` requires using SubViewOp or
174     // SubTensorOps. If the value isnt defined from there continue.
175     // todo: The method should be adapted to get the values from
176     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
177     // currently returns a `linalg.range`. The fix here is to move this op to
178     // `std` dialect and add the method to `ViewInterface`.
179     if (fromSubViewOpOnly &&
180         !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
181       continue;
182 
183     unsigned idx = en.index();
184     auto map = maps[idx].cast<AffineMapAttr>().getValue();
185     LLVM_DEBUG(llvm::dbgs()
186                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
187     LLVM_DEBUG(llvm::dbgs()
188                << "getShapeDefiningLoopRange map: " << map << "\n");
189     Value shape = en.value();
190     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
191     for (auto en2 : llvm::enumerate(map.getResults())) {
192       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
193       if (!dimExpr)
194         continue;
195       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
196         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
197                                 << loopDepth << "\n");
198         LLVM_DEBUG(llvm::dbgs()
199                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
200         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
201       }
202     }
203   }
204   llvm_unreachable("Expect to be able to extract a shape defining loop range");
205 }
206 
207 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
208 /// provides the loop range information for the fused loops. The rest are
209 /// obtained from the producer itself, since they are not tiled + fused.
210 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
211                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
212 
213   unsigned nPar = producer.getNumParallelLoops();
214   unsigned nRed = producer.getNumReductionLoops();
215   unsigned nWin = producer.getNumWindowLoops();
216   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
217   for (auto fusedLoops : fusedLoopsAndRanges)
218     loopRanges[fusedLoops.first] = fusedLoops.second;
219 
220   // Iterate over all dimensions. For the dimensions not identified by the
221   // producer map for `producerIdx`, we need to explicitly compute the shape
222   // that defines the loop ranges using the `producer`.
223   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
224     if (loopRanges[i].offset)
225       LLVM_DEBUG(llvm::dbgs()
226                  << "existing LoopRange: " << loopRanges[i] << "\n");
227     else {
228       auto shapeDim = getShapeDefiningLoopRange(producer, i);
229       loopRanges[i] = Range{std_constant_index(0),
230                             std_dim(shapeDim.shape, shapeDim.dimension),
231                             std_constant_index(1)};
232       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
233     }
234   }
235 
236   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
237 }
238 
239 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
240 /// expected to be defined by a subview op or a subtensor op.
241 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
242                                       Value shapedOperand, unsigned dim) {
243   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
244   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
245     return subViewOp.getOrCreateRanges(b, loc)[dim];
246   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
247     return subTensorOp.getOrCreateRanges(b, loc)[dim];
248   llvm_unreachable("SubviewOp or SubTensorOp expected");
249 }
250 
251 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
252 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
253 /// is needed just before the `consumer.
254 ///
255 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
256 /// 2 cases:
257 ///   1. Buffer case: `producerIdx` is the index of the buffer in
258 ///      `producer.getOutputBuffers()`.
259 ///   2. Tensor case: `producerIdx` is the index of the tensor in
260 ///      `producer.getResults()`.
261 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
262                      LinalgOp consumer, unsigned consumerIdx) {
263   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
264   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
265                           << ", producer map: " << producerMap << "\n");
266   DenseMap<unsigned, Range> fusedLoopsAndRanges;
267   Location loc = consumer.getLoc();
268   Value shapedOperand = consumer.getShapedOperand(consumerIdx);
269   for (auto en : llvm::enumerate(producerMap.getResults())) {
270     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
271     fusedLoopsAndRanges[posInProducerLoop] =
272         getRangeFromOperandShape(b, loc, shapedOperand, en.index());
273   }
274   return fuse(b, producer, fusedLoopsAndRanges);
275 }
276 
277 // Encode structural fusion safety preconditions.
278 // Some of these will be lifted in the future with better analysis.
279 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
280                                           LinalgOp consumer) {
281   assert(producer.hasBufferSemantics() &&
282          "expected linalg op with buffer semantics");
283   assert(consumer.hasBufferSemantics() &&
284          "expected linalg op with buffer semantics");
285   if (producer.getNumOutputs() != 1) {
286     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
287     return false;
288   }
289   // Only fuse when the producer block dominates.
290   DominanceInfo dom(producer.getOperation());
291   if (!dom.dominates(producer.getOperation()->getBlock(),
292                      consumer.getOperation()->getBlock())) {
293     LLVM_DEBUG(
294         llvm::dbgs()
295         << "\nNot structurally fusable (producer block does not dominate)");
296     return false;
297   }
298   return true;
299 }
300 
301 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
302                                              LinalgOp consumer,
303                                              Value consumedView,
304                                              LinalgOp producer) {
305   assert(producer.hasBufferSemantics() &&
306          "expected linalg op with buffer semantics");
307   assert(consumer.hasBufferSemantics() &&
308          "expected linalg op with buffer semantics");
309   // Make some simple structural checks that alleviate the need for more
310   // complex analyses.
311   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
312     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
313                             << *producer.getOperation());
314     return false;
315   }
316   // Check for any interleaved write to consumedView.
317   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
318     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
319                             << *producer.getOperation());
320     return false;
321   }
322   return true;
323 }
324 
325 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
326                                  LinalgOp consumer, Value consumedView,
327                                  LinalgOp producer) {
328   assert(producer.hasBufferSemantics() &&
329          "expected linalg op with buffer semantics");
330   assert(consumer.hasBufferSemantics() &&
331          "expected linalg op with buffer semantics");
332   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
333     return false;
334   // Check for any fusion-preventing dependence to any shape read/written that
335   // would violate dependences.
336   if (!graph.findCoveringDependences(producer, consumer).empty()) {
337     LLVM_DEBUG(llvm::dbgs()
338                << "\n***Not fusable due to an interleaved dependence:\t"
339                << *producer.getOperation());
340     return false;
341   }
342   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
343     // TODO: add a level of indirection to linalg.generic.
344     if (convOp.padding())
345       return false;
346   }
347   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
348     // TODO: add a level of indirection to linalg.generic.
349     if (convOp.padding())
350       return false;
351   }
352   return true;
353 }
354 
355 static bool isSameSubView(Value a, Value b) {
356   if (a == b)
357     return true;
358   auto sva = a.getDefiningOp<SubViewOp>();
359   auto svb = b.getDefiningOp<SubViewOp>();
360   if (!sva || !svb)
361     return false;
362   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
363     return false;
364   if (sva.getType() != svb.getType())
365     return false;
366   if (sva.getNumOperands() != svb.getNumOperands())
367     return false;
368   if (sva.static_offsets() != svb.static_offsets())
369     return false;
370   if (sva.static_sizes() != svb.static_sizes())
371     return false;
372   if (sva.static_strides() != svb.static_strides())
373     return false;
374   /// Skip the "source" operand.
375   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
376     if (sva.getOperand(idx) != svb.getOperand(idx))
377       return false;
378   return true;
379 }
380 
381 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
382 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
383                     const LinalgDependenceGraph &dependenceGraph) {
384   // Only consider RAW and WAW atm.
385   for (auto depType : {
386            LinalgDependenceGraph::DependenceType::RAW,
387            LinalgDependenceGraph::DependenceType::WAW,
388        }) {
389     for (auto dependence : llvm::make_filter_range(
390              dependenceGraph.getDependencesInto(consumer, depType),
391              [consumerIdx](
392                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
393                return elem.indexingOpView.operandIndex == consumerIdx;
394              })) {
395       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
396 
397       // Check that the dependence is indeed on the input `consumerIdx` view.
398       auto consumedView =
399           consumer.getBuffer(dependence.indexingOpView.operandIndex);
400       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
401         continue;
402 
403       // Consumer consumes this view, `isStructurallyFusableProducer` also
404       // checks whether it is a strict subview of the producer view.
405       auto producedView =
406           producer.getBuffer(dependence.dependentOpView.operandIndex);
407       LLVM_DEBUG(llvm::dbgs()
408                  << "\n"
409                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
410                  << "producer: " << *producer.getOperation()
411                  << " view: " << producedView << " output index: "
412                  << dependence.dependentOpView.operandIndex -
413                         producer.getNumInputs()
414                  << "\n");
415       (void)producedView;
416 
417       // Simple fusability checks.
418       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
419         continue;
420 
421       return dependence;
422     }
423   }
424   return {};
425 }
426 
427 Optional<FusionInfo>
428 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
429                                    unsigned consumerIdx,
430                                    const LinalgDependenceGraph &graph) {
431   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
432       findFusableProducer(consumer, consumerIdx, graph);
433   if (!fusableDependence)
434     return {};
435 
436   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
437   // If producer is already in the same block as consumer, we are done.
438   if (consumer.getOperation()->getBlock() ==
439       producerOp.getOperation()->getBlock())
440     return {};
441 
442   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
443                          producerOp.getNumInputs();
444   Value consumerView = consumer.getShapedOperand(consumerIdx);
445 
446   // Must be a subview or a slice to guarantee there are loops we can fuse
447   // into.
448   auto subView = consumerView.getDefiningOp<SubViewOp>();
449   auto slice = consumerView.getDefiningOp<SliceOp>();
450   if (!subView && !slice) {
451     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
452     return {};
453   }
454 
455   // Fuse `producer` just before `consumer`.
456   OpBuilder::InsertionGuard g(b);
457   b.setInsertionPoint(consumer.getOperation());
458   ScopedContext scope(b, consumer.getLoc());
459   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
460 
461   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
462   return FusionInfo{producerOp, fusedProducer};
463 }
464 
465 /// Walk back use-def chain through scf::For yields.
466 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
467 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
468                                 unsigned &outputIndex) {
469   if (!tensor.getType().isa<RankedTensorType>())
470     return;
471 
472   while (true) {
473     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
474       producer = linalgOp;
475       outputIndex = tensor.cast<OpResult>().getResultNumber();
476       return;
477     }
478     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
479       tensor = subTensorOp.source();
480       continue;
481     }
482     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
483       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
484         tensor = forOp.getResult(blockArg.getArgNumber());
485         continue;
486       }
487     }
488     return;
489   }
490 }
491 
492 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
493                                                         LinalgOp consumer,
494                                                         unsigned consumerIdx) {
495   Value inputTensor = consumer.getInput(consumerIdx);
496   LinalgOp producerOp;
497   unsigned producerIdx;
498   getProducerOfTensor(inputTensor, producerOp, producerIdx);
499 
500   // Must be a subtensor to guarantee there are loops we can fuse into.
501   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
502   if (!subTensor || !producerOp) {
503     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
504     return {};
505   }
506 
507   // If producer is already in the same block as consumer, we are done.
508   if (consumer.getOperation()->getBlock() ==
509       producerOp.getOperation()->getBlock())
510     return {};
511 
512   // Insert fused `producer` just before `consumer`.
513   OpBuilder::InsertionGuard g(b);
514   b.setInsertionPoint(consumer.getOperation());
515   ScopedContext scope(b, consumer.getLoc());
516   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
517   LinalgOp fusedProducer =
518       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
519 
520   // Replace use.
521   // Canonicalizations are not guaranteed to have happened before constructing
522   // `fusedProducer`. In the tensor case this can result in temporary type
523   // mismatches. Insert a `tensor_cast` op to propagate the transformation
524   // invariant that types are compatible.
525   Value def = fusedProducer.getOperation()->getResult(producerIdx);
526   OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
527   Type consumerType = use.get().getType();
528   if (consumerType != def.getType())
529     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
530   use.set(def);
531   return FusionInfo{producerOp, fusedProducer};
532 }
533 
534 /// Prune all dimensions that are of reduction iterator type from `map`.
535 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
536                                            AffineMap map) {
537   SmallVector<unsigned, 2> projectedDims;
538   for (auto attr : llvm::enumerate(iteratorTypes)) {
539     if (!isParallelIterator(attr.value()))
540       projectedDims.push_back(attr.index());
541   }
542   return getProjectedMap(map, projectedDims);
543 }
544 
545 /// Returns the mapping from iterations in the consumer that write to the same
546 /// location as the iterations in the producer. To do so use
547 /// - indexing map of the fused view in the consumer : consumerIndexMap
548 /// - indexing map of the fused view in the producer : producerIndexMap
549 ///     consumerLoopToProducerLoop =
550 ///       inverse(producerIndexMap).compose(consumerIndexMap)
551 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
552     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
553   auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
554   AffineMap producerIndexingMap =
555       producer.getIndexingMap(dependence.dependentOpView.operandIndex);
556   auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
557   AffineMap consumerIndexingMap =
558       consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
559 
560   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
561       producer.iterator_types().getValue(), producerIndexingMap);
562   if (!prunedProducerIndexingMap.isPermutation())
563     return None;
564 
565   if (consumerIndexingMap.getNumResults() !=
566       prunedProducerIndexingMap.getNumResults())
567     return None;
568 
569   LLVM_DEBUG({
570     llvm::dbgs() << "\t producerMap : ";
571     producerIndexingMap.print(llvm::dbgs());
572     llvm::dbgs() << "  pruned : ";
573     prunedProducerIndexingMap.print(llvm::dbgs());
574     llvm::dbgs() << "\n";
575     llvm::dbgs() << "\t consumerMap : ";
576     consumerIndexingMap.print(llvm::dbgs());
577     llvm::dbgs() << "\n";
578   });
579 
580   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
581   if (!invProducerIndexMap)
582     return None;
583 
584   return invProducerIndexMap.compose(consumerIndexingMap);
585 }
586 
587 /// Given a projected permutation `map`, returns true if the map changes the
588 /// order in which the fused loop dimension appear.
589 static bool doesTransposeAccess(AffineMap map,
590                                 const std::set<unsigned> &fusableLoops) {
591   Optional<unsigned> lastFusableLoop;
592   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
593          return expr.cast<AffineDimExpr>().getPosition();
594        })) {
595     if (!fusableLoops.count(pos))
596       continue;
597     if (!lastFusableLoop) {
598       lastFusableLoop = pos;
599       continue;
600     }
601     if (pos <= lastFusableLoop.getValue())
602       return true;
603     lastFusableLoop = pos;
604   }
605   return false;
606 }
607 
608 /// Returns the positions of the loop in `op` that can be tiled based on the
609 /// operations that are to be fused with it. For example, in a
610 ///
611 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
612 ///
613 /// if the producer of %a needs to be fused with this op, only the `i` loop of
614 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
615 /// fused, then no loops can be tiled while fusing. The conditions used are:
616 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
617 ///    common outer parallel loops between the op and its producers being fused.
618 /// 2. Of the parallel loops only some can be fused. Only those loops can be
619 ///    fused such where the fusable loops iteration space only touches one tile
620 ///    of the fused operation. This is because the producer (which is writing
621 ///    the fused subview) has update semantics.
622 ///
623 /// Since an inverse computation is needed, we need to consider the projection
624 /// of the producerIndexMap w.r.t the parallel loops.  The actual fusable loops
625 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
626 /// parallel loops and appear in the result of the map
627 ///
628 /// Example 1:
629 ///   linalg.fill(%c, %cst)
630 ///   linalg.matmul ins(%a, %b) outs(%c)
631 ///     Number of parallel loops : 2
632 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
633 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
634 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
635 ///     Fused dimensions : i, j
636 ///
637 /// Example 2:
638 ///   linalg.matmul ins(%a, %b) outs(%c)
639 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
640 ///                   iterator_types = ["parallel", "parallel"]}
641 ///     ins(%c) ...
642 ///
643 ///     Number of parallel loops = 2:
644 ///     producerIndexMap (projected to parallel loops) =
645 ///       affine_map<(i, j) -> (i, j)>
646 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
647 ///     Fused dimensions : i, j
648 ///
649 /// Example 3:
650 ///   linalg.copy(%s, %b)
651 ///   linalg.matmul ins(%a, %b) outs(%c)
652 ///
653 ///   Number of parallel loops = 2
654 ///   produceIndexMap : affine_map<(i, j) -> (i, j)>
655 ///   consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
656 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
657 ///   Fused dimensions : j
658 static std::set<unsigned>
659 collectFusableLoops(ArrayRef<LinalgOp> ops,
660                     const FusableOpDependencesTy &fusableDependences) {
661   assert(!ops.empty());
662   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
663     return linalgOp.iterator_types()
664         .getValue()
665         .take_while([](Attribute attr) -> bool {
666           return attr.cast<StringAttr>().getValue() ==
667                  getParallelIteratorTypeName();
668         })
669         .size();
670   };
671 
672   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
673   for (auto op : ops.drop_back()) {
674     numOuterParallelLoops =
675         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
676   }
677 
678   std::set<unsigned> fusableLoops;
679   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
680   fusableLoops.insert(range.begin(), range.end());
681 
682   for (auto op : reverse(ops)) {
683     for (auto dependence : fusableDependences.lookup(op)) {
684       LLVM_DEBUG({
685         llvm::dbgs() << "\t fusable :";
686         for (unsigned i : fusableLoops)
687           llvm::dbgs() << " " << i;
688         llvm::dbgs() << "\n";
689       });
690 
691       Optional<AffineMap> consumerLoopToProducerLoop =
692           getConsumerLoopToProducerLoopMap(dependence);
693       if (!consumerLoopToProducerLoop) {
694         op.emitRemark("failed to get map from consumer loop to producer loop");
695         return {};
696       }
697       // todo: This condition is only an implementation limitation. When fusing
698       // the operation, if the accesses in the producer/consumer are transposes
699       // of each other, the loop bounds for the tiled producer can be
700       // manipulated accordingly. This requires some additional bookkeeping in
701       // the implementation of tile+fuse that is defered to later.
702       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
703         op.emitRemark("unhandled fusion when fusion requires permutation");
704         return {};
705       }
706 
707       std::set<unsigned> candidates;
708       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
709         unsigned position = expr.cast<AffineDimExpr>().getPosition();
710         if (fusableLoops.count(position))
711           candidates.insert(position);
712       }
713       LLVM_DEBUG({
714         llvm::dbgs() << "\t candidates :";
715         for (unsigned i : candidates)
716           llvm::dbgs() << " " << i;
717         llvm::dbgs() << "\n";
718       });
719       if (candidates.empty())
720         return {};
721       std::swap(candidates, fusableLoops);
722     }
723   }
724 
725   return fusableLoops;
726 }
727 
728 /// Find all dependences that are fusable.
729 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
730     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
731   FusableOpDependencesTy fusableDependences;
732   // TODO: Currently fusion would not be legal if the fusable dependence is to
733   // the same producer but different indexing map in the consumer. Fix this, but
734   // in the meanwhile disallow such a fusion.
735   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
736   for (LinalgOp op : reverse(ops)) {
737     for (auto operandIndex :
738          llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
739       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
740           fusableDependence =
741               findFusableProducer(op, operandIndex, dependenceGraph);
742       if (!fusableDependence)
743         continue;
744       LinalgOp producerOp =
745           cast<LinalgOp>(fusableDependence->dependentOpView.op);
746       // Do not fuse dependences that are to operations not in the same basic
747       // block. This avoid moving fused operations across loops that might
748       // themselves carry dependency making the fusion illegal.
749       if (producerOp.getOperation()->getBlock() !=
750           op.getOperation()->getBlock()) {
751         op.emitRemark("unhandled fusion of ops in different basic blocks");
752         return FusableOpDependencesTy{};
753       }
754       // Make sure that the indexing map of the view used for fusion in the
755       // producer is a projected permutation.
756       unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
757       AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
758       if (!producerMap.isProjectedPermutation()) {
759         op.emitRemark(
760             "unhandled non permutation indexing map for fused view in "
761             "producer for operand at index ")
762             << operandIndex;
763         return FusableOpDependencesTy{};
764       }
765 
766       unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
767       AffineMap consumerMap = op.getIndexingMap(consumerIdx);
768       if (!consumerMap.isProjectedPermutation()) {
769         op.emitRemark(
770             "unhandled case where indexing map for fused view in the consumer "
771             "is "
772             "not a projected permuration while fusing at index ")
773             << operandIndex;
774         return FusableOpDependencesTy{};
775       }
776 
777       // Check if the producer is already a fusion candidate. Cannot fuse this
778       // dependence if it has a different indexing map when used in the
779       // consumer.
780       if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
781           fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
782         op.emitRemark(
783             "unhandled fusion to the same producer but with different "
784             "indexing maps");
785         return FusableOpDependencesTy{};
786       }
787       fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
788 
789       fusableDependences[producerOp.getOperation()].push_back(
790           *fusableDependence);
791     }
792   }
793   return fusableDependences;
794 }
795 
796 static bool isZero(Value v) {
797   if (auto cst = v.getDefiningOp<ConstantIndexOp>())
798     return cst.getValue() == 0;
799   return false;
800 }
801 
802 /// Tile the fused loops in the root operation, by setting the tile sizes for
803 /// all other loops to zero (those will be tiled later).
804 static Optional<TiledLinalgOp> tileRootOperation(
805     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
806     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
807   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
808   auto zero = std_constant_index(0);
809   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
810     if (!fusedLoops.count(i))
811       tileSizes[i] = zero;
812   LinalgTilingOptions tileFusedLoopsOptions = options;
813   tileFusedLoopsOptions.setTileSizes(tileSizes);
814   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
815 }
816 
817 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
818 /// to be a tiled operation such that it is valid to fuse all operations in
819 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
820 /// `tiledOp`.
821 static SmallVector<LinalgOp, 1>
822 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
823                ArrayRef<LinalgOp> fusionCandidates,
824                const FusableOpDependencesTy &fusableDependences,
825                const std::set<unsigned> &fusedLoops) {
826   OpBuilder::InsertionGuard guard(builder);
827   builder.setInsertionPoint(tiledOp);
828   DenseMap<unsigned, Range> fusedLoopsAndRanges;
829   for (unsigned loop : fusedLoops) {
830     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
831     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
832         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
833   }
834 
835   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
836   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
837     LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
838     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
839     builder.setInsertionPoint(fusedOp);
840   }
841   return fusedOps;
842 }
843 
844 template <typename LoopType>
845 static Optional<TiledAndFusedLinalgOps>
846 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
847                          const LinalgDependenceGraph &dependenceGraph,
848                          const LinalgTilingOptions &tilingOptions) {
849   if (ops.empty())
850     return llvm::None;
851   LinalgOp rootOp = ops.back();
852   for (auto op : enumerate(ops)) {
853     // TODO: Nothing in the fusion of sequence of ops is specific to
854     // buffers. This check can be removed after it is tested on tensors.
855     LinalgOp linalgOp = op.value();
856     if (!linalgOp.hasBufferSemantics()) {
857       linalgOp.emitError("tile and fuse only tested for buffer operation");
858       return llvm::None;
859     }
860   }
861   // TODO: Support interchange with tile + fuse. This might actually help do
862   // better fusion.
863   if (!tilingOptions.interchangeVector.empty()) {
864     rootOp.emitError("unable to handle tile and fuse with interchange");
865     return llvm::None;
866   }
867 
868   OpBuilder::InsertionGuard guard(builder);
869   builder.setInsertionPoint(rootOp);
870   ScopedContext scope(builder, rootOp.getLoc());
871 
872   // Find all the producers.
873   FusableOpDependencesTy fusableDependences =
874       findAllFusableDependences(ops, dependenceGraph);
875   if (fusableDependences.empty())
876     return llvm::None;
877 
878   TiledAndFusedLinalgOps ret;
879   // Find the loops that can be tiled and fused.
880   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
881 
882   // If there are no fusable dependences or there are no tile+fusable loops,
883   // just return.
884   if (ret.fusedLoopDims.empty()) {
885     return llvm::None;
886   }
887 
888   // Tile the fused loops in the last operation in the list.
889   SmallVector<Value, 4> tileSizeVector =
890       tilingOptions.tileSizeComputationFunction(builder, rootOp);
891   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
892       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
893   if (!tiledRootOp) {
894     rootOp.emitError("failed to tile the fused loops");
895     return llvm::None;
896   }
897   ret.op = tiledRootOp->op;
898   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
899 
900   // Fuse the other operations into the fused inter-tile loops produced above.
901   ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
902                                       fusableDependences, ret.fusedLoopDims);
903   return ret;
904 }
905 
906 Optional<TiledAndFusedLinalgOps>
907 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
908                                    const LinalgDependenceGraph &dependenceGraph,
909                                    const LinalgTilingOptions &tilingOptions) {
910   switch (tilingOptions.loopType) {
911   case LinalgTilingLoopType::Loops:
912     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
913                                                 tilingOptions);
914   case LinalgTilingLoopType::ParallelLoops:
915     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
916         builder, ops, dependenceGraph, tilingOptions);
917   default:;
918   }
919   return llvm::None;
920 }
921