//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map of the producer's result in the
/// consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is same as the consumer loop. For each producer arg
  // the indexing map to be computed is a map from consumer loop -> producer
  // arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
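
// For illustration (hypothetical maps, not taken from a specific test): with
//   producerResultIndexMap   = (p0, p1) -> (p1, p0)
//   argMap                   = (p0, p1) -> (p0)
//   fusedConsumerArgIndexMap = (c0, c1) -> (c0, c1)
// the inverse of producerResultIndexMap is (r0, r1) -> (r1, r0); composing it
// with argMap gives (r0, r1) -> (r1), and composing that with
// fusedConsumerArgIndexMap yields (c0, c1) -> (c1), i.e. the producer operand
// expressed in the fused (consumer) loop coordinates.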

// Checks if the given operand can be dropped, and whether the remaining
// operands of the fused producer & consumer can still compute the loop
// bounds of the op after the fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
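
// For example (a hypothetical configuration, for illustration): with
// remaining maps (d0, d1) -> (d0) and (d0, d1) -> (d1), the concatenation
// (d0, d1) -> (d0, d1) is invertible, so both loop bounds remain computable
// and the candidate operand can be dropped. If instead only (d0, d1) -> (d0)
// remained, d1 would no longer be derivable from any operand shape and the
// operand cannot be dropped.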

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * Note that the actual fusion implementation may not agree with the result
/// of this method; this function gives a prediction based on an optimized
/// fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // The operand cannot be dropped: remove it from the ignore list.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}
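
// For example (an assumed scenario, for illustration): if the producer yields
// two results, where the first is consumed only by `consumer` through
// `fusedOperand` and the second has uses outside the consumer, then the
// second result must be preserved, while the first can be dropped provided
// its init is unused in the payload and the remaining operands still
// determine all loop bounds.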

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
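
// A minimal example of a fusable pair (illustrative only, not taken from a
// test):
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %0 = linalg.generic {indexing_maps = [#map, #map],
//                      iterator_types = ["parallel"]}
//     ins(%a : tensor<?xf32>) outs(%init0 : tensor<?xf32>) {.. }
// %1 = linalg.generic {indexing_maps = [#map, #map],
//                      iterator_types = ["parallel"]}
//     ins(%0 : tensor<?xf32>) outs(%init1 : tensor<?xf32>) {.. }
// ```
//
// The use of %0 in %1 satisfies all conditions above: both ops are generics,
// the producer has pure tensor semantics and is all-parallel, the fused
// operand is an input of the consumer, and the producer result map #map is an
// invertible permutation.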

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity check: if the replacement is not already in the mapper, then it
  // must be produced outside the producer.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield, built from the remapped values yielded by
  // both the producer (preserved results only) and the consumer.
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}
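
// Continuing the illustrative example shown after `areElementwiseOpsFusable`
// above (two all-parallel 1-d generics %0 and %1): the fused op is, roughly,
//
// ```mlir
// %fused = linalg.generic {indexing_maps = [#map, #map],
//                          iterator_types = ["parallel"]}
//     ins(%a : tensor<?xf32>) outs(%init1 : tensor<?xf32>) {.. }
// ```
//
// where the region contains the producer payload followed by the consumer
// payload, and the consumer's block argument for %0 is replaced by the
// producer's yielded value. Since %0 has no users other than %1, it is not
// preserved as a result of the fused op.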

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the original values with the results of the fused op.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6D, reshapes can be
///  introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from the extent of a loop in the original operation to the
  /// extents of the corresponding loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
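
// A worked example (hypothetical values, for illustration): for a 2-loop op
// where the fused operand has indexing map (d0, d1) -> (d1, d0), and the
// reshape expands result dim 0 into two dims and result dim 1 into one, with
// expandedShape = [8, 4, 16]:
//   - d1 (result 0) gets numExpandedDims = 2, expandedShapeMap[d1] = [8, 4]
//   - d0 (result 1) gets numExpandedDims = 1, expandedShapeMap[d0] = [16]
// so reassociation = [[0], [1, 2]] and expandedOpNumDims = 3.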

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the expanded
/// op. That requires the shape of the expanded dimensions to be static (at
/// least all but the most significant). For now check that these are all
/// statically sized. Note that this could be extended to handle the dynamic
/// case, but the implementation below uses `affine.apply` which seems to have
/// issues when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op, given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
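
// For example (continuing the hypothetical expansion above, with
// d0 -> {e0} and d1 -> {e1, e2}): an original map (d0, d1) -> (d1, d0)
// becomes (e0, e1, e2) -> (e1, e2, e0) in the expanded op.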

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}
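
// For example (same hypothetical expansion as above, d0 -> {e0} and
// d1 -> {e1, e2}): an operand with indexing map (d0, d1) -> (d1, d0) gets the
// reassociation [[0, 1], [2]], i.e. its first two expanded dims collapse back
// to the original d1 dimension and the last one back to d0.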

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
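
// For example (a hypothetical expansion, for illustration): if original dim
// d0 is expanded into (e0, e1) with shapes (?, 4), the original
// `linalg.index 0` is recovered as
//
// ```mlir
// %e0 = linalg.index 0 : index
// %e1 = linalg.index 1 : index
// %i0 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 4)>(%e1, %e0)
// ```
//
// i.e. `idx + acc * extent`, with the inner index as `idx` and the running
// linearized value as `acc`.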

/// Checks if a single dynamic dimension was expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}
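
// For instance, expanding a dynamic dimension into (?, 4) is fine (the
// dynamic extent is uniquely determined by dividing by 4), but expanding it
// into (?, ?) is rejected since the two dynamic extents cannot be inferred
// from the collapsed one.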

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  // TODO: With the support of multiple dynamic dims expansion in
  // tensor.expand_shape op, this case can be handled.
  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // Iterator types of the expanded op: initialize all dims as parallel, then
  // copy each original dim's iterator type to the dims it expands into.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
/// in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - all constraints of fusing with the reshape by expansion are met,
      // - the fusion is accepted by the control function.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

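/// Pattern to fold a tensor.collapse_shape op producing the source of a
/// tensor.pad op, by moving the padding to the expanded (source) tensor and
/// collapsing the padded result. Only dimensions that are not collapsed
/// (reassociation groups of size 1) may be padded.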
class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer reshape
    // op can be replaced by the source of the collapse_shape op that defines
    // the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape ops with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. The projected
/// permutation semantics guarantee that all the elements of the returned
/// reassociation are distinct.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}
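
// For example (hypothetical): with indexingMap = (d0, d1, d2) -> (d2, d0, d1)
// and rangeReassociation = [1, 2] (the last two results), the corresponding
// domain indices are [0, 1].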
1160 
1161 /// For a given `dimSequence`, check if the sequence is conserved in the
1162 /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
1163 /// Non-existence of the sequence returns true as well.
1164 bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
1165                                           ReassociationIndicesRef dimSequence) {
1166   assert(!dimSequence.empty() &&
1167          "expected non-empty list for dimension sequence");
1168   assert(indexingMap.isProjectedPermutation() &&
1169          "expected indexing map to be projected permutation");
1170 
1171   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
1172   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
1173 
1174   unsigned dimSequenceStart = dimSequence[0];
1175   for (const auto &expr : enumerate(indexingMap.getResults())) {
1176     unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
1177     // 1.  Check if this start of the sequence.
1178     if (dimInMapStart == dimSequenceStart) {
1179       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
1180         return false;
1181       // 1a. Check if sequence is preserved.
1182       for (const auto &dimInSequence : enumerate(dimSequence)) {
1183         unsigned dimInMap =
1184             cast<AffineDimExpr>(
1185                 indexingMap.getResult(expr.index() + dimInSequence.index()))
1186                 .getPosition();
1187         if (dimInMap != dimInSequence.value())
1188           return false;
1189       }
1190       // Found the sequence. Projected permutation
1191       // enforces that all AffineDimExprs in the result are unique, so no
1192       // further checks are needed.
1193       return true;
1194     }
1195     // 2. If the position of this expr (an AffineDimExpr) is part of the
1196     // sequence but is not its start, the sequence cannot appear contiguously
1197     // in the indexing map; return false.
1198     if (sequenceElements.count(dimInMapStart))
1199       return false;
1200   }
1201   // 3. No element of sequence found. Return true.
1202   return true;
1203 }
1204 
1205 bool mlir::linalg::areDimSequencesPreserved(
1206     ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
1207   return llvm::all_of(maps, [&](AffineMap map) {
1208     return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1209       return isDimSequencePreserved(map, dimSequence);
1210     });
1211   });
1212 }
1213 
1214 // Return the list of dimensions of the iteration domain that can be
1215 // collapsed to allow for fusion with a producer that is an expand_shape
1216 // operation. If all dimensions created by expansion can be collapsed in the
1217 // iteration space then the reshape is defunct.
1218 //
1219 // Example:
1220 //
1221 // ```mlir
1222 // #map = affine_map<(d0, d1) -> (d0, d1)>
1223 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1224 // %2 = tensor.empty [..] : tensor<?x4xf32>
1225 // %3 = linalg.generic {
1226 //     indexing_maps = [#map, #map],
1227 //     iterator_types = ["parallel" ,"parallel"]}
1228 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1229 // ```
1230 //
1231 // can be fused by collapsing the dimensions of the iteration space.
1232 //
1233 // ```mlir
1234 // #map = affine_map<(d0) -> (d0)>
1235 // %2 = tensor.empty [..] : tensor<?xf32>
1236 // %3 = linalg.generic {
1237 //     indexing_maps = [#map, #map],
1238 //     iterator_types = ["parallel"]}
1239 //     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1240 // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1241 // ```
1242 //
1243 // In the following example,
1244 //
1245 // ```mlir
1246 // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1247 // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1248 // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1249 // %2 = tensor.empty [..] : tensor<4x?xf32>
1250 // %3 = linalg.generic {
1251 //     indexing_maps = [#map0, #map1],
1252 //     iterator_types = ["parallel" ,"parallel"]}
1253 //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1254 // ```
1255 //
1256 // the reshape cannot be fused with the generic op by collapsing the op
1257 // dimensions since the indexing maps will have to contain mods and divs
1258 // to preserve the access pattern. When no dimensions of the iteration
1259 // space are collapsible, an empty vector is returned.
1260 static SmallVector<ReassociationIndices>
1261 getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
1262                                  ArrayRef<ReassociationIndices> reassociation) {
1263   // Some basic checks for this fusion to be valid.
1264   if (!genericOp.hasPureTensorSemantics())
1265     return {};
1266 
1267   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
1268         return map.isProjectedPermutation();
1269       })) {
1270     return {};
1271   }
1272 
1273   // Compute all the loops with the reduction iterator types.
1274   SmallVector<unsigned> reductionDims;
1275   genericOp.getReductionDims(reductionDims);
1276 
1277   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
1278   AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1279   auto iteratorTypes = genericOp.getIteratorTypesArray();
1280   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1281   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1282     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1283 
1284     // Ignore dims that are not folded.
1285     if (foldedRangeDims.size() == 1)
1286       continue;
1287 
1288     ReassociationIndices foldedIterationSpaceDims =
1289         getDomainReassociation(indexingMap, foldedRangeDims);
1290 
1291     // Check that the folded iteration dims do not contain already processed
1292     // dims.
1293     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1294           return processedIterationDims.count(dim);
1295         }))
1296       continue;
1297 
1298     // Check that the folded iterator types are either all parallel or all reduction.
1299     utils::IteratorType startIteratorType =
1300         iteratorTypes[foldedIterationSpaceDims[0]];
1301     if (!isParallelIterator(startIteratorType) &&
1302         !isReductionIterator(startIteratorType))
1303       continue;
1304     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1305           return iteratorTypes[dim] != startIteratorType;
1306         }))
1307       continue;
1308 
1309     // If the folded dimensions correspond to a "reduction" iterator type,
1310     // the folded dimensions need to be "in-order". Strictly speaking this is
1311     // not necessary for reductions that are associative and commutative, but
1312     // a stricter definition of reduction is used for now.
1313     if (isReductionIterator(startIteratorType)) {
1314       bool isContiguous = false;
1315       for (const auto &startDim : llvm::enumerate(reductionDims)) {
1316         // Move window in `reductionDims` to start of the folded iteration dims.
1317         if (startDim.value() != foldedIterationSpaceDims[0])
1318           continue;
1319         // If the sizes don't match, it is trivially not contiguous. This
1320         // condition should not be hit.
1321         if (startDim.index() + foldedIterationSpaceDims.size() >
1322             reductionDims.size())
1323           break;
1324         // Check that the contiguity is maintained.
1325         isContiguous = true;
1326         for (const auto &foldedDim :
1327              llvm::enumerate(foldedIterationSpaceDims)) {
1328           if (reductionDims[foldedDim.index() + startDim.index()] !=
1329               foldedDim.value()) {
1330             isContiguous = false;
1331             break;
1332           }
1333         }
1334         break;
1335       }
1336       if (!isContiguous)
1337         continue;
1338     }
1339 
1340     // Check that the sequence is preserved in all indexing maps.
1341     if (llvm::any_of(genericOp.getIndexingMapsArray(),
1342                      [&](AffineMap indexingMap) {
1343                        return !isDimSequencePreserved(indexingMap,
1344                                                       foldedIterationSpaceDims);
1345                      }))
1346       continue;
1347 
1348     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1349                                   foldedIterationSpaceDims.end());
1350     iterationSpaceReassociation.emplace_back(
1351         std::move(foldedIterationSpaceDims));
1352   }
1353 
1354   return iterationSpaceReassociation;
1355 }
1356 
1357 /// Helper class to carry state while collapsing the `linalg.generic` op.
1358 namespace {
1359 class CollapsingInfo {
1360 public:
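  /// Initialize from the list of folded iteration dims. As an illustrative
  /// example: with `origNumLoops = 4` and `foldedIterationDims = [[1, 2]]`,
  /// the collapsed-to-original mapping becomes `[[0], [1, 2], [3]]`, i.e. the
  /// collapsed op has three loops. Fails if the specification repeats a
  /// dimension or references one outside `[0, origNumLoops)`.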
1361   LogicalResult initialize(unsigned origNumLoops,
1362                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1363     llvm::SmallDenseSet<int64_t, 4> processedDims;
1364     // Find all the dims that are folded.
1365     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1366       if (foldedIterationDim.empty())
1367         continue;
1368       // If the folded dims contain dims that were already folded, the
1369       // specification is illegal. Repetition within a list is also illegal.
1370       for (auto dim : foldedIterationDim) {
1371         if (dim >= origNumLoops)
1372           return failure();
1373         if (processedDims.count(dim))
1374           return failure();
1375         processedDims.insert(dim);
1376       }
1377       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1378                                                    foldedIterationDim.end());
1379     }
1380     if (processedDims.size() > origNumLoops)
1381       return failure();
1382 
1383     // Add all the preserved dims of the original op as single
1384     // elements to `collapsedOpToOrigOpIterationDim`.
1385     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1386       if (processedDims.count(dim))
1387         continue;
1388       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
1389     }
1390 
1391     llvm::sort(collapsedOpToOrigOpIterationDim,
1392                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1393                  return lhs[0] < rhs[0];
1394                });
1395     origOpToCollapsedOpIterationDim.resize(origNumLoops);
1396     for (const auto &foldedDims :
1397          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
1398       for (const auto &dim : enumerate(foldedDims.value()))
1399         origOpToCollapsedOpIterationDim[dim.value()] =
1400             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1401     }
1402     return success();
1403   }
1404 
1405   /// Return mapping from collapsed loop domain to original loop domain.
1406   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1407     return collapsedOpToOrigOpIterationDim;
1408   }
1409 
1410   /// Return mapping from original loop domain to collapsed loop domain. The
1411   /// mapping is a pair. First value is the dimension in the collapsed loop that
1412   /// the original loop is mapped to. Second is the relative position in the
1413   /// folded list of this domain. For example, if the original loop domain is
1414   /// 5D and is folded into two loops, i.e.
1415   ///
1416   /// ```
1417   /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1418   /// ```
1419   ///
1420   /// then
1421   ///
1422   /// ```
1423   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1424   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1425   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1426   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1427   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1428   /// ```
1429   ///
1430   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1431     return origOpToCollapsedOpIterationDim;
1432   }
1433 
1434   /// Return the collapsed op iteration domain rank.
1435   unsigned getCollapsedOpIterationRank() const {
1436     return collapsedOpToOrigOpIterationDim.size();
1437   }
1438 
1439 private:
1440   /// Map from the iteration domain index in collapsed op to the iteration
1441   /// domain indices in the original op.
1442   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1443 
1444   /// Map from iteration domain index in the original op to the iteration domain
1445   /// index in the collapsed op.
1446   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1447 };
1448 } // namespace
1449 
1450 /// Get the iterator types for the collapsed operation given the original
1451 /// iterator types and collapsed dimensions.
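/// For example (illustrative), with original iterator types `["parallel",
/// "reduction", "reduction", "parallel"]` and folded iteration dims
/// `[[1, 2]]`, the collapsed iterator types are
/// `["parallel", "reduction", "parallel"]`.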
1452 static SmallVector<utils::IteratorType>
1453 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1454                             const CollapsingInfo &collapsingInfo) {
1455   SmallVector<utils::IteratorType> collapsedIteratorTypes;
1456   for (ReassociationIndicesRef foldedIterDims :
1457        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1458     assert(!foldedIterDims.empty() &&
1459            "reassociation indices expected to have non-empty sets");
1460     // Just pick the iterator type of the first folded dim; precondition checks
1461     // are expected to have verified that the iterator types of all folded
1462     // dimensions are the same.
1463     collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1464   }
1465   return collapsedIteratorTypes;
1466 }
1467 
1468 /// Compute the indexing map in the collapsed op that corresponds to the given
1469 /// `indexingMap` of the original operation.
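/// For example (illustrative), with
/// `indexingMap = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>` and folded
/// iteration dims `[[1, 2]]` (i.e. a collapsed domain of
/// `[[0], [1, 2], [3]]`), the collapsed map is
/// `affine_map<(d0, d1, d2) -> (d0, d1)>`.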
1470 static AffineMap
1471 getCollapsedOpIndexingMap(AffineMap indexingMap,
1472                           const CollapsingInfo &collapsingInfo) {
1473   MLIRContext *context = indexingMap.getContext();
1474   assert(indexingMap.isProjectedPermutation() &&
1475          "expected indexing map to be projected permutation");
1476   SmallVector<AffineExpr> resultExprs;
1477   auto origOpToCollapsedOpMapping =
1478       collapsingInfo.getOrigOpToCollapsedOpMapping();
1479   for (auto expr : indexingMap.getResults()) {
1480     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1481     // If the dim is not the first of its collapsed group, do nothing.
1482     if (origOpToCollapsedOpMapping[dim].second != 0)
1483       continue;
1484     // The next n-dims are guaranteed to be collapsed. So just use the
1485     // iteration dimension of the collapsed op.
1486     resultExprs.push_back(
1487         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1488   }
1489   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1490                         resultExprs, context);
1491 }
1492 
1493 /// Return the `reassociation` indices to use to collapse the operand when the
1494 /// iteration space of a generic op is collapsed.
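/// Continuing the illustrative example above: for an operand with indexing
/// map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>` and folded iteration
/// dims `[[1, 2]]`, the operand reassociation is `[[0], [1, 2]]`, i.e.
/// operand dims 1 and 2 are collapsed into one.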
1495 static SmallVector<ReassociationIndices>
1496 getOperandReassociation(AffineMap indexingMap,
1497                         const CollapsingInfo &collapsingInfo) {
1498   unsigned counter = 0;
1499   SmallVector<ReassociationIndices> operandReassociation;
1500   auto origOpToCollapsedOpMapping =
1501       collapsingInfo.getOrigOpToCollapsedOpMapping();
1502   auto collapsedOpToOrigOpMapping =
1503       collapsingInfo.getCollapsedOpToOrigOpMapping();
1504   while (counter < indexingMap.getNumResults()) {
1505     unsigned dim =
1506         cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1507     // This is the start of a group of collapsed iteration dimensions that
1508     // is guaranteed to be preserved in the indexing map. The number of folded
1509     // dims is obtained from the collapsed op to original op mapping.
1510     unsigned numFoldedDims =
1511         collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1512             .size();
1513     if (origOpToCollapsedOpMapping[dim].second == 0) {
1514       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1515       operandReassociation.emplace_back(range.begin(), range.end());
1516     }
1517     counter += numFoldedDims;
1518   }
1519   return operandReassociation;
1520 }
1521 
1522 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1523 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1524                                    OpOperand *opOperand,
1525                                    const CollapsingInfo &collapsingInfo,
1526                                    OpBuilder &builder) {
1527   AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1528   SmallVector<ReassociationIndices> operandReassociation =
1529       getOperandReassociation(indexingMap, collapsingInfo);
1530 
1531   // If the number of entries in the reassociation for the operand is the same
1532   // as the number of results of the indexing map, there is nothing to do for
1533   // this operand.
1534   Value operand = opOperand->get();
1535   if (operandReassociation.size() == indexingMap.getNumResults())
1536     return operand;
1537 
1538   // Insert a reshape to collapse the dimensions.
1539   if (isa<MemRefType>(operand.getType())) {
1540     return builder
1541         .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1542         .getResult();
1543   }
1544   return builder
1545       .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1546       .getResult();
1547 }
1548 
1549 /// Modify the `linalg.index` operations in the original generic op to their
1550 /// values in the collapsed operation.
1551 void generateCollapsedIndexingRegion(Location loc, Block *block,
1552                                      const CollapsingInfo &collapsingInfo,
1553                                      ValueRange loopRange,
1554                                      RewriterBase &rewriter) {
1555   OpBuilder::InsertionGuard g(rewriter);
1556   rewriter.setInsertionPointToStart(block);
1557 
1558   // Collect all the original index ops.
1559   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1560 
1561   // For each folded dimension list resolve the original induction variable
1562   // values in terms of the folded dimension induction variable.
1563   //   i_{folded} = (i0 * d1 + i1) * d2 + i2
1564   // can be inverted to
1565   //   i2 = i_{folded} % d2
1566   //   i1 = (i_{folded} / d2) % d1
1567   //   i0 = i_{folded} / (d1 * d2)
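  // As a sketch (illustrative SSA names), for a folded list [i0, i1, i2] with
  // loop sizes %d1 and %d2, the rewritten region reads:
  //   %fold = linalg.index 0 : index   // i_{folded}
  //   %i2 = arith.remsi %fold, %d2 : index
  //   %t  = arith.divsi %fold, %d2 : index
  //   %i1 = arith.remsi %t, %d1 : index
  //   %i0 = arith.divsi %t, %d1 : index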
1568   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1569   for (auto foldedDims :
1570        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1571     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1572     Value newIndexVal =
1573         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1574     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1575       indexReplacementVals[dim] =
1576           rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
1577       newIndexVal =
1578           rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
1579     }
1580     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1581   }
1582 
1583   for (auto indexOp : indexOps) {
1584     auto dim = indexOp.getDim();
1585     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1586   }
1587 }
1588 
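/// Compute the collapsed input/output operands and result types for the
/// collapsed op: each operand is collapsed via `getCollapsedOpOperand`, and
/// result types are taken from the new init operands (no result types are
/// collected for ops with pure buffer semantics).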
1589 void collapseOperandsAndResults(LinalgOp op,
1590                                 const CollapsingInfo &collapsingInfo,
1591                                 RewriterBase &rewriter,
1592                                 SmallVectorImpl<Value> &inputOperands,
1593                                 SmallVectorImpl<Value> &outputOperands,
1594                                 SmallVectorImpl<Type> &resultTypes) {
1595   Location loc = op->getLoc();
1596   inputOperands =
1597       llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1598         return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1599                                      rewriter);
1600       });
1601 
1602   // Get the output operands and result types.
1603   resultTypes.reserve(op.getNumDpsInits());
1604   outputOperands.reserve(op.getNumDpsInits());
1605   for (OpOperand &output : op.getDpsInitsMutable()) {
1606     Value newOutput =
1607         getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1608     outputOperands.push_back(newOutput);
1609     // If the op has "buffer semantics", then the init operands are ranked
1610     // memrefs and the op has no results.
1611     if (!op.hasPureBufferSemantics())
1612       resultTypes.push_back(newOutput.getType());
1613   }
1614 }
1615 
1616 /// Clone a `LinalgOp` to a collapsed version of the same name.
1617 template <typename OpTy>
1618 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1619                         const CollapsingInfo &collapsingInfo) {
1620   return nullptr;
1621 }
1622 
1623 /// Collapse any `LinalgOp` that does not require any specialization of its
1624 /// indexing_maps, iterator_types, etc.
1625 template <>
1626 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1627                                       const CollapsingInfo &collapsingInfo) {
1628   SmallVector<Value> inputOperands, outputOperands;
1629   SmallVector<Type> resultTypes;
1630   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1631                              outputOperands, resultTypes);
1632 
1633   return clone(
1634       rewriter, origOp, resultTypes,
1635       llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1636 }
1637 
1638 /// Collapse a `GenericOp`.
1639 template <>
1640 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1641                                         GenericOp origOp,
1642                                         const CollapsingInfo &collapsingInfo) {
1643   SmallVector<Value> inputOperands, outputOperands;
1644   SmallVector<Type> resultTypes;
1645   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1646                              outputOperands, resultTypes);
1647   SmallVector<AffineMap> indexingMaps(
1648       llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
1649         return getCollapsedOpIndexingMap(map, collapsingInfo);
1650       }));
1651 
1652   SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1653       origOp.getIteratorTypesArray(), collapsingInfo));
1654 
1655   GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1656       origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
1657       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1658   Block *origOpBlock = &origOp->getRegion(0).front();
1659   Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
1660   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
1661                        collapsedOpBlock->getArguments());
1662   return collapsedOp;
1663 }
1664 
1665 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
1666                            RewriterBase &rewriter) {
1667   if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1668     return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1669   } else {
1670     return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1671   }
1672 }
1673 
1674 /// Implementation of fusion with reshape operation by collapsing dimensions.
1675 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1676     LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1677     RewriterBase &rewriter) {
1678   // Bail on trivial no-op cases.
1679   if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1680       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1681         return foldedDims.size() <= 1;
1682       }))
1683     return failure();
1684 
1685   bool hasPureBufferSemantics = op.hasPureBufferSemantics();
1686   if (hasPureBufferSemantics &&
1687       !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1688         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1689         if (!memRefToCollapse)
1690           return true;
1691 
1692         return memref::CollapseShapeOp::isGuaranteedCollapsible(
1693             memRefToCollapse, foldedIterationDims);
1694       }))
1695     return rewriter.notifyMatchFailure(op,
1696                                        "memref is not guaranteed collapsible");
1697 
1698   CollapsingInfo collapsingInfo;
1699   if (failed(
1700           collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1701     return rewriter.notifyMatchFailure(
1702         op, "illegal to collapse specified dimensions");
1703   }
1704 
1705   // Bail on non-canonical ranges.
1706   SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
1707   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
1708     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
1709       return cast<IntegerAttr>(attr).getInt() == value;
1710     llvm::APInt actual;
1711     return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
1712            actual.getSExtValue() == value;
1713   };
1714   if (!llvm::all_of(loopRanges, [&](Range range) {
1715         return opFoldIsConstantValue(range.offset, 0) &&
1716                opFoldIsConstantValue(range.stride, 1);
1717       })) {
1718     return rewriter.notifyMatchFailure(
1719         op, "expected all loop ranges to have zero start and unit stride");
1720   }
1721 
1722   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
1723 
1724   Location loc = op->getLoc();
1725   if (collapsedOp.hasIndexSemantics()) {
1726     // Collect the loop range of the generic op.
1727     OpBuilder::InsertionGuard g(rewriter);
1728     rewriter.setInsertionPoint(collapsedOp);
1729     SmallVector<Value> loopBound =
1730         llvm::map_to_vector(loopRanges, [&](Range range) {
1731           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
1732         });
1733     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
1734                                     collapsingInfo, loopBound, rewriter);
1735   }
1736 
1737   // Insert expanding reshape for the result to get back the original result
1738   // type.
1739   SmallVector<Value> results;
1740   for (const auto &originalResult : llvm::enumerate(op->getResults())) {
1741     Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
1742     auto originalResultType =
1743         cast<ShapedType>(originalResult.value().getType());
1744     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
1745     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
1746       AffineMap indexingMap =
1747           op.getIndexingMapMatchingResult(originalResult.value());
1748       SmallVector<ReassociationIndices> reassociation =
1749           getOperandReassociation(indexingMap, collapsingInfo);
1750       Value result;
1751       if (isa<MemRefType>(collapsedOpResult.getType())) {
1752         MemRefType expandShapeResultType = MemRefType::get(
1753             originalResultType.getShape(), originalResultType.getElementType());
1754         result = rewriter.create<memref::ExpandShapeOp>(
1755             loc, expandShapeResultType, collapsedOpResult, reassociation);
1756       } else {
1757         result = rewriter.create<tensor::ExpandShapeOp>(
1758             loc, originalResultType, collapsedOpResult, reassociation);
1759       }
1760       results.push_back(result);
1761     } else {
1762       results.push_back(collapsedOpResult);
1763     }
1764   }
1765   return CollapseResult{results, collapsedOp};
1766 }
1767 
1768 namespace {
1769 
1770 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
1771 /// collapsing dimensions of its iteration space.
1772 class FoldWithProducerReshapeOpByCollapsing
1773     : public OpRewritePattern<GenericOp> {
1774 public:
1775   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
1776                                         ControlFusionFn foldReshapes,
1777                                         PatternBenefit benefit = 1)
1778       : OpRewritePattern<GenericOp>(context, benefit),
1779         controlFoldingReshapes(std::move(foldReshapes)) {}
1780 
1781   LogicalResult matchAndRewrite(GenericOp genericOp,
1782                                 PatternRewriter &rewriter) const override {
1783     for (OpOperand &opOperand : genericOp->getOpOperands()) {
1784       tensor::ExpandShapeOp reshapeOp =
1785           opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
1786       if (!reshapeOp)
1787         continue;
1788 
1789       SmallVector<ReassociationIndices> collapsableIterationDims =
1790           getCollapsableIterationSpaceDims(genericOp, &opOperand,
1791                                            reshapeOp.getReassociationIndices());
1792       if (collapsableIterationDims.empty() ||
1793           !controlFoldingReshapes(&opOperand)) {
1794         continue;
1795       }
1796 
1797       std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
1798           genericOp, collapsableIterationDims, rewriter);
1799       if (!collapseResult) {
1800         return rewriter.notifyMatchFailure(
1801             genericOp, "failed to do the fusion by collapsing transformation");
1802       }
1803 
1804       rewriter.replaceOp(genericOp, collapseResult->results);
1805       return success();
1806     }
1807     return failure();
1808   }
1809 
1810 private:
1811   ControlFusionFn controlFoldingReshapes;
1812 };
1813 
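/// Pattern to fold a tensor.expand_shape op with its consumer tensor.pad op
/// by swapping the two: the padding is applied to the collapsed source of the
/// reshape, and the result is expanded back afterwards. This only applies
/// when the reshape has a single use and none of the dimensions that get
/// expanded into multiple dimensions are padded.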
1814 class FoldPadWithProducerReshapeOpByCollapsing
1815     : public OpRewritePattern<tensor::PadOp> {
1816 public:
1817   FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1818                                            ControlFusionFn foldReshapes,
1819                                            PatternBenefit benefit = 1)
1820       : OpRewritePattern<tensor::PadOp>(context, benefit),
1821         controlFoldingReshapes(std::move(foldReshapes)) {}
1822 
1823   LogicalResult matchAndRewrite(tensor::PadOp padOp,
1824                                 PatternRewriter &rewriter) const override {
1825     tensor::ExpandShapeOp reshapeOp =
1826         padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1827     if (!reshapeOp)
1828       return failure();
1829     if (!reshapeOp->hasOneUse())
1830       return failure();
1831 
1832     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1833       return rewriter.notifyMatchFailure(padOp,
1834                                          "fusion blocked by control function");
1835     }
1836 
1837     ArrayRef<int64_t> low = padOp.getStaticLow();
1838     ArrayRef<int64_t> high = padOp.getStaticHigh();
1839     SmallVector<ReassociationIndices> reassociations =
1840         reshapeOp.getReassociationIndices();
1841 
1842     for (auto reInd : reassociations) {
1843       if (reInd.size() == 1)
1844         continue;
1845       if (llvm::any_of(reInd, [&](int64_t ind) {
1846             return low[ind] != 0 || high[ind] != 0;
1847           })) {
1848         return failure();
1849       }
1850     }
1851 
1852     SmallVector<OpFoldResult> newLow, newHigh;
1853     RankedTensorType collapsedType = reshapeOp.getSrcType();
1854     RankedTensorType paddedType = padOp.getResultType();
1855     SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1856     SmallVector<OpFoldResult> expandedPaddedSizes(
1857         getMixedValues(reshapeOp.getStaticOutputShape(),
1858                        reshapeOp.getOutputShape(), rewriter));
1859     AffineExpr d0, d1, d2;
1860     bindDims(rewriter.getContext(), d0, d1, d2);
1861     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1862     Location loc = reshapeOp->getLoc();
1863     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1864       OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1865       OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1866       if (reInd.size() == 1) {
1867         collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1868         OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1869             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1870         expandedPaddedSizes[reInd[0]] = paddedSize;
1871       }
1872       newLow.push_back(l);
1873       newHigh.push_back(h);
1874     }
1875 
1876     RankedTensorType collapsedPaddedType =
1877         paddedType.clone(collapsedPaddedShape);
1878     auto newPadOp = rewriter.create<tensor::PadOp>(
1879         loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1880         padOp.getConstantPaddingValue(), padOp.getNofold());
1881 
1882     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1883         padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1884         expandedPaddedSizes);
1885 
1886     return success();
1887   }
1888 
1889 private:
1890   ControlFusionFn controlFoldingReshapes;
1891 };
1892 
1893 /// Pattern to collapse dimensions.
1894 template <typename LinalgType>
1895 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
1896 public:
1897   CollapseLinalgDimensions(MLIRContext *context,
1898                            GetCollapsableDimensionsFn collapseDimensions,
1899                            PatternBenefit benefit = 1)
1900       : OpRewritePattern<LinalgType>(context, benefit),
1901         controlCollapseDimension(std::move(collapseDimensions)) {}
1902 
1903   LogicalResult matchAndRewrite(LinalgType op,
1904                                 PatternRewriter &rewriter) const override {
1905     SmallVector<ReassociationIndices> collapsableIterationDims =
1906         controlCollapseDimension(op);
1907     if (collapsableIterationDims.empty())
1908       return failure();
1909 
1910     // Check if the specified list of dimensions to collapse is a valid list.
1911     if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1912                                   collapsableIterationDims)) {
1913       return rewriter.notifyMatchFailure(
1914           op, "specified dimensions cannot be collapsed");
1915     }
1916 
1917     std::optional<CollapseResult> collapseResult =
1918         collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1919     if (!collapseResult) {
1920       return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
1921     }
1922     rewriter.replaceOp(op, collapseResult->results);
1923     return success();
1924   }
1925 
1926 private:
1927   GetCollapsableDimensionsFn controlCollapseDimension;
1928 };
1929 
1930 } // namespace
1931 
1932 //===---------------------------------------------------------------------===//
1933 // Methods and patterns that fuse constants with linalg.generic operations.
1934 //===---------------------------------------------------------------------===//
1935 
1936 namespace {
1937 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
1938 /// handle cases where the constant is not single-valued.
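///
/// As an illustrative example, an input operand produced by
///   %cst = arith.constant dense<4.0> : tensor<4xf32>
/// is dropped from the fused op; the scalar value 4.0 is materialized as an
/// arith.constant and mapped to the corresponding payload block argument.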
1939 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1940 public:
1941   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1942       : OpRewritePattern<GenericOp>(context, benefit) {}
1943 
1944   LogicalResult matchAndRewrite(GenericOp genericOp,
1945                                 PatternRewriter &rewriter) const override {
1946     if (!genericOp.hasPureTensorSemantics())
1947       return failure();
1948     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1949       Operation *def = opOperand->get().getDefiningOp();
1950       TypedAttr constantAttr;
1951       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1952         {
1953           DenseElementsAttr splatAttr;
1954           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1955               splatAttr.isSplat() &&
1956               splatAttr.getType().getElementType().isIntOrFloat()) {
1957             constantAttr = splatAttr.getSplatValue<TypedAttr>();
1958             return true;
1959           }
1960         }
1961         {
1962           IntegerAttr intAttr;
1963           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1964             constantAttr = intAttr;
1965             return true;
1966           }
1967         }
1968         {
1969           FloatAttr floatAttr;
1970           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1971             constantAttr = floatAttr;
1972             return true;
1973           }
1974         }
1975         return false;
1976       };
1977 
1978       auto resultValue = dyn_cast<OpResult>(opOperand->get());
1979       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1980         continue;
1981 
1982       // The operands and the indexing_maps of the fused operation are the same
1983       // as the operands and indexing_maps of the generic operation, with the
1984       // value at the constant index dropped.
1985       SmallVector<AffineMap> fusedIndexMaps;
1986       SmallVector<Value> fusedOperands;
1987       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1988       fusedIndexMaps.reserve(genericOp->getNumOperands());
1989       fusedOperands.reserve(genericOp.getNumDpsInputs());
1990       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1991       for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1992         if (inputOperand == opOperand)
1993           continue;
1994         Value inputValue = inputOperand->get();
1995         fusedIndexMaps.push_back(
1996             genericOp.getMatchingIndexingMap(inputOperand));
1997         fusedOperands.push_back(inputValue);
1998         fusedLocs.push_back(inputValue.getLoc());
1999       }
2000       for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2001         fusedIndexMaps.push_back(
2002             genericOp.getMatchingIndexingMap(&outputOperand));
2003 
2004       // Check if the operation shapes to loops map is computable.
2005       if (!inversePermutation(
2006               concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2007         return rewriter.notifyMatchFailure(
2008             genericOp, "fused op loop bound computation failed");
2009       }
2010 
2011       // Create a constant scalar value from the splat constant.
2012       Value scalarConstant =
2013           rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2014 
2015       SmallVector<Value> outputOperands = genericOp.getOutputs();
2016       auto fusedOp = rewriter.create<GenericOp>(
2017           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2018           /*inputs=*/fusedOperands,
2019           /*outputs=*/outputOperands,
2020           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2021           genericOp.getIteratorTypes(),
2022           /*doc=*/nullptr,
2023           /*library_call=*/nullptr);
2024 
2025       // Map the block argument corresponding to the replaced argument with the
2026       // scalar constant.
2027       Region &region = genericOp->getRegion(0);
2028       Block &entryBlock = *region.begin();
2029       IRMapping mapping;
2030       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2031                   scalarConstant);
2032       Region &fusedRegion = fusedOp->getRegion(0);
2033       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2034                                  mapping);
2035       rewriter.replaceOp(genericOp, fusedOp->getResults());
2036       return success();
2037     }
2038     return failure();
2039   }
2040 };
2041 
2042 } // namespace
2043 
2044 //===---------------------------------------------------------------------===//
2045 // Miscellaneous patterns that help fusion.
2046 //===---------------------------------------------------------------------===//
2047 
2048 namespace {
2049 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2050 /// value of the `outs` operand is not used within the op. This is only
2051 /// implemented for `linalg.generic` operations for now, but should hold for all
2052 /// linalg structured ops.
2053 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2054   using OpRewritePattern<GenericOp>::OpRewritePattern;
2055 
2056   LogicalResult matchAndRewrite(GenericOp op,
2057                                 PatternRewriter &rewriter) const override {
2058     rewriter.startOpModification(op);
2059     bool modifiedOutput = false;
2060     Location loc = op.getLoc();
2061     for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2062       if (!op.payloadUsesValueFromOperand(&opOperand)) {
2063         Value operandVal = opOperand.get();
2064         auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2065         if (!operandType)
2066           continue;
2067 
2068         // If outs is sparse, leave it to the sparsifier.
2069         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2070           continue;
2071 
2072         // If outs is already an `empty` operation, nothing to do.
2073         auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
2074         if (definingOp)
2075           continue;
2076         modifiedOutput = true;
2077         SmallVector<OpFoldResult> mixedSizes =
2078             tensor::getMixedSizes(rewriter, loc, operandVal);
2079         Value emptyTensor = rewriter.create<tensor::EmptyOp>(
2080             loc, mixedSizes, operandType.getElementType());
2081         op->setOperand(opOperand.getOperandNumber(), emptyTensor);
2082       }
2083     }
2084     if (!modifiedOutput) {
2085       rewriter.cancelOpModification(op);
2086       return failure();
2087     }
2088     rewriter.finalizeOpModification(op);
2089     return success();
2090   }
2091 };
2092 
2093 /// Fold linalg.fill into linalg.generic
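///
/// As an illustrative example, if an input of the generic op is produced by
///   %fill = linalg.fill ins(%cst : f32)
///                       outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
/// then all payload uses of the corresponding block argument are replaced by
/// %cst (converted to the element type of the fill result if necessary).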
2094 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
2095   using OpRewritePattern<GenericOp>::OpRewritePattern;
2096 
2097   LogicalResult matchAndRewrite(GenericOp genericOp,
2098                                 PatternRewriter &rewriter) const override {
2099     if (!genericOp.hasPureTensorSemantics())
2100       return failure();
2101     bool fillFound = false;
2102     Block &payload = genericOp.getRegion().front();
2103     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
2104       if (!genericOp.payloadUsesValueFromOperand(opOperand))
2105         continue;
2106       FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
2107       if (!fillOp)
2108         continue;
2109       fillFound = true;
2110       Value fillVal = fillOp.value();
2111       auto resultType =
2112           cast<RankedTensorType>(fillOp.result().getType()).getElementType();
2113       Value convertedVal =
2114           convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
2115                                /*isUnsignedCast =*/false);
2116       rewriter.replaceAllUsesWith(
2117           payload.getArgument(opOperand->getOperandNumber()), convertedVal);
2118     }
2119     return success(fillFound);
2120   }
2121 };
2122 } // namespace
2123 
2124 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
2125     RewritePatternSet &patterns,
2126     const ControlFusionFn &controlFoldingReshapes) {
2127   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2128                                                     controlFoldingReshapes);
2129   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2130                                                         controlFoldingReshapes);
2131   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
2132                                                      controlFoldingReshapes);
2133 }
2134 
2135 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
2136     RewritePatternSet &patterns,
2137     const ControlFusionFn &controlFoldingReshapes) {
2138   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
2139                                                       controlFoldingReshapes);
2140   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2141       patterns.getContext(), controlFoldingReshapes);
2142 }
2143 
2144 void mlir::linalg::populateElementwiseOpsFusionPatterns(
2145     RewritePatternSet &patterns,
2146     const ControlFusionFn &controlElementwiseOpsFusion) {
2147   auto *context = patterns.getContext();
2148   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
2149   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
2150                RemoveOutsDependency>(context);
2151   // Add the patterns that clean up dead operands and results.
2152   populateEraseUnusedOperandsAndResultsPatterns(patterns);
2153 }
2154 
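// A minimal sketch of wiring up a control function (the `patterns` set and
// the chosen dimensions here are illustrative assumptions): request that
// loops 1 and 2 of every candidate op be collapsed.
//
//   GetCollapsableDimensionsFn control = [](linalg::LinalgOp op) {
//     return SmallVector<ReassociationIndices>{{1, 2}};
//   };
//   populateCollapseDimensions(patterns, control);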
2155 void mlir::linalg::populateCollapseDimensions(
2156     RewritePatternSet &patterns,
2157     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
2158   patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
2159                CollapseLinalgDimensions<linalg::CopyOp>>(
2160       patterns.getContext(), controlCollapseDimensions);
2161 }
2162 
2163 //===---------------------------------------------------------------------===//
2164 // Passes
2165 //===---------------------------------------------------------------------===//
2166 
2167 namespace {
2168 
2169 /// Pass that fuses generic ops on tensors. Used only for testing.
2170 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
2171 // patterns added here heavily depends on the cost function used. Having an
2172 // opinionated pass of this form is not recommended. Deprecate this pass in
2173 // favor of test passes that check the functionality of each of the patterns
2174 // added here individually.
2175 struct LinalgElementwiseOpFusionPass
2176     : public impl::LinalgElementwiseOpFusionPassBase<
2177           LinalgElementwiseOpFusionPass> {
2178   using impl::LinalgElementwiseOpFusionPassBase<
2179       LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
2180   void runOnOperation() override {
2181     Operation *op = getOperation();
2182     MLIRContext *context = op->getContext();
2183     RewritePatternSet patterns(context);
2184 
2185     // Add folding with reshape by expansion patterns.
2186     ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2187       Operation *producer = fusedOperand->get().getDefiningOp();
2188       return producer && producer->hasOneUse();
2189     };
2190 
2191     // Add elementwise op fusion patterns.
2192     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
2193     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2194     tensor::populateBubbleUpExpandShapePatterns(patterns);
2195 
2196     // General canonicalization patterns.
2197     affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
2198     GenericOp::getCanonicalizationPatterns(patterns, context);
2199     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
2200     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
2201     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
2202         patterns);
2203 
2204     // Add constant folding patterns.
2205     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
2206 
2207     // Use top-down traversal for compile-time reasons.
2208     GreedyRewriteConfig grc;
2209     grc.useTopDownTraversal = true;
2210     (void)applyPatternsGreedily(op, std::move(patterns), grc);
2211   }
2212 };
2213 
2214 } // namespace
2215