//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Returns the indexing map to use for an operand of the `producer` in the
/// fused operation, given the indexing map of the producer's result
/// (`producerResultIndexMap`) and the indexing map of the fused operand in
/// the consumer (`fusedConsumerArgIndexMap`).
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop / fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}
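
// For intuition, a minimal sketch with hypothetical maps: if the producer's
// result map is (d0, d1) -> (d1, d0), its inverse permutation is
// (t0, t1) -> (t1, t0). Composing a producer argument map (d0, d1) -> (d0)
// with that inverse yields (t0, t1) -> (t1), and composing the result with a
// consumer operand map (c0, c1) -> (c1, c0) gives the final fused-loop map
// (c0, c1) -> (c0).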

// Checks whether the given operand can be dropped, i.e. whether the remaining
// operands of the fused producer & consumer can still compute the bounds of
// the op after the fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}
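
// A minimal sketch of the invertibility criterion (hypothetical maps): with
// remaining maps (d0, d1) -> (d0) and (d0, d1) -> (d1), the concatenation
// (d0, d1) -> (d0, d1) is invertible, so the candidate operand can be
// dropped. If only (d0, d1) -> (d0) remained, d1 would no longer be
// inferable and the operand must be kept.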

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * Note: there is a chance that the implementation of the transformation
/// does not agree with the result of this method. This function gives a
/// prediction based on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check operand itself has tensor
  // type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that all the producer's iterator types are "parallel".
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that after
  // the fusion, each loop dimension has at least one input that defines it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}
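
// A minimal IR sketch of a fusable producer/consumer pair (hypothetical
// example, with #id = affine_map<(d0) -> (d0)>): both ops are parallel
// generics on tensors, and the producer's result indexing map is a
// permutation, so the preconditions above hold for the use of %0 in %1:
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%a : tensor<?xf32>) outs(%init : tensor<?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%0, %b : tensor<?xf32>, tensor<?xf32>)
//        outs(%init : tensor<?xf32>) {...}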

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}
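
// For reference, the fused block's argument order built above mirrors the
// fused operand list assembled in `fuseElementwiseOps` below:
//   [consumer inputs before the fused operand] ++ [producer inputs] ++
//   [consumer inputs after the fused operand] ++
//   [preserved producer outputs] ++ [consumer outputs]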

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // Find the results of the producer that have uses outside of the consumer
  // after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So clean up and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Pattern to fuse a generic op with the producer of one of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Fuse the producer of the operand into this consumer.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace the results of the original ops with the fused op's results.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace
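
// Usage sketch, assuming the standard population entry point declared in
// Linalg's Transforms.h (`populateElementwiseOpsFusionPatterns`); the control
// function decides per-operand whether fusion should be applied:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(
//       patterns, [](OpOperand *fusedOperand) { return true; });
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));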

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `fuseWithReshapeByExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map help compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the operands to the linalg generic are now higher dimensional,
///  reshapes can be introduced to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extents of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}
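
// Worked example for the reshape in the header comment above: with fused
// operand map (d0, d1, d2) -> (d0, d2, d1) and reassociation
// [[0, 1], [2], [3, 4, 5]], the map's results give numExpandedDims of
// 2 for d0, 1 for d2, and 3 for d1, so the computed reassociation is
// d0 -> [0, 1], d1 -> [2, 3, 4], d2 -> [5], and the expanded op has 6 loops.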

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
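
// For example, continuing the sketch above: with d0 -> [e0, e1],
// d1 -> [e2, e3, e4], d2 -> [e5], an original map (d0, d1, d2) -> (d1, d0, d2)
// expands to (e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5).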

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}
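
// E.g., a hypothetical tensor<10x4xf32> operand with map (d0, d1) -> (d0, d1),
// where d0 expands into extents [2, 5] and d1 stays [4], becomes
// tensor<2x5x4xf32>.
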
7275994201cSMaheshRavishankar 
728206365bfSAlexander Belyaev /// Returns the reassociation maps to use in the `tensor.expand_shape`
7295994201cSMaheshRavishankar /// operation to convert the operands of the original operation to operands of
7305994201cSMaheshRavishankar /// the expanded operation. The same method is used to compute the
731206365bfSAlexander Belyaev /// `tensor.collapse_shape` used to collapse the result of the expanded
7325994201cSMaheshRavishankar /// op to get the value that can replace all uses of the results of the original
7335994201cSMaheshRavishankar /// op.
7345994201cSMaheshRavishankar static SmallVector<ReassociationIndices>
7355994201cSMaheshRavishankar getReassociationForExpansion(AffineMap indexingMap,
7365994201cSMaheshRavishankar                              const ExpansionInfo &expansionInfo) {
7375994201cSMaheshRavishankar   SmallVector<ReassociationIndices> reassociation;
7385994201cSMaheshRavishankar   unsigned numReshapeDims = 0;
7395994201cSMaheshRavishankar   for (AffineExpr expr : indexingMap.getResults()) {
7401609f1c2Slong.chen     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
7415994201cSMaheshRavishankar     auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
7425994201cSMaheshRavishankar     SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
7435994201cSMaheshRavishankar         llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
7445994201cSMaheshRavishankar     reassociation.emplace_back(std::move(indices));
7455994201cSMaheshRavishankar     numReshapeDims += numExpandedDims;
7465994201cSMaheshRavishankar   }
7475994201cSMaheshRavishankar   return reassociation;
7485994201cSMaheshRavishankar }
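
// Continuing the header example: for the operand with map
// (d0, d1, d2) -> (d1, d0, d2), the groups have sizes 3, 2, and 1, giving the
// reassociation [[0, 1, 2], [3, 4], [5]] used by its tensor.expand_shape.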

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
7925994201cSMaheshRavishankar 
79397069a86SGaurav Shukla /// Checks that no single dynamic dimension is expanded into multiple dynamic
79497069a86SGaurav Shukla /// dimensions, since the expanded sizes cannot be inferred in that case.
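/// For example (a sketch), expanding `tensor<?xf32>` into `tensor<?x?xf32>` is
/// rejected, while expanding it into `tensor<?x4xf32>` is fine since the single
/// dynamic size can be inferred from the collapsed shape.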
79597069a86SGaurav Shukla static LogicalResult
79697069a86SGaurav Shukla validateDynamicDimExpansion(LinalgOp linalgOp,
79797069a86SGaurav Shukla                             const ExpansionInfo &expansionInfo,
79897069a86SGaurav Shukla                             PatternRewriter &rewriter) {
79997069a86SGaurav Shukla   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
80097069a86SGaurav Shukla     ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
80197069a86SGaurav Shukla     if (expandedShape.size() == 1)
80297069a86SGaurav Shukla       continue;
80397069a86SGaurav Shukla     bool foundDynamic = false;
80497069a86SGaurav Shukla     for (int64_t shape : expandedShape) {
80597069a86SGaurav Shukla       if (!ShapedType::isDynamic(shape))
80697069a86SGaurav Shukla         continue;
80797069a86SGaurav Shukla       if (foundDynamic) {
80897069a86SGaurav Shukla         return rewriter.notifyMatchFailure(
80997069a86SGaurav Shukla             linalgOp, "cannot infer expanded shape with multiple dynamic "
81097069a86SGaurav Shukla                       "dims in the same reassociation group");
81197069a86SGaurav Shukla       }
81297069a86SGaurav Shukla       foundDynamic = true;
81397069a86SGaurav Shukla     }
81497069a86SGaurav Shukla   }
81597069a86SGaurav Shukla   return success();
81697069a86SGaurav Shukla }
81797069a86SGaurav Shukla 
8181ad9b266Slorenzo chelini /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
8195994201cSMaheshRavishankar /// and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
8205994201cSMaheshRavishankar /// Assumes that those conditions have been satisfied.
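///
/// As an illustrative sketch, fusing a `tensor.collapse_shape` producer into
/// an elementwise consumer expands the loops of the consumer and collapses the
/// results back:
///
/// ```mlir
/// %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
/// %1 = linalg.generic {..} ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {.. }
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = tensor.expand_shape %init [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
/// %1 = linalg.generic {..} ins(%arg0 : tensor<?x4xf32>) outs(%0 : tensor<?x4xf32>) {.. }
/// %2 = tensor.collapse_shape %1 [[0, 1]] : tensor<?x4xf32> into tensor<?xf32>
/// ```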
8210a81ace0SKazu Hirata static std::optional<SmallVector<Value>>
8223f18f6a2SQuinn Dawkins fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
8235994201cSMaheshRavishankar                            OpOperand *fusableOpOperand,
8245994201cSMaheshRavishankar                            PatternRewriter &rewriter) {
8253f18f6a2SQuinn Dawkins   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
8265994201cSMaheshRavishankar          "preconditions for fuse operation failed");
82797069a86SGaurav Shukla 
82897069a86SGaurav Shukla   Location loc = linalgOp.getLoc();
8295994201cSMaheshRavishankar   // Check if reshape is expanding or collapsing.
830b618880eSAlexander Belyaev   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
831b618880eSAlexander Belyaev   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
8325994201cSMaheshRavishankar   bool isExpanding = (expandingReshapeOp != nullptr);
8335994201cSMaheshRavishankar   RankedTensorType expandedType = isExpanding
8345994201cSMaheshRavishankar                                       ? expandingReshapeOp.getResultType()
8355994201cSMaheshRavishankar                                       : collapsingReshapeOp.getSrcType();
8364317a3dfSMaheshRavishankar   RankedTensorType collapsedType = isExpanding
8374317a3dfSMaheshRavishankar                                        ? expandingReshapeOp.getSrcType()
8384317a3dfSMaheshRavishankar                                        : collapsingReshapeOp.getResultType();
8395994201cSMaheshRavishankar 
8405994201cSMaheshRavishankar   ExpansionInfo expansionInfo;
8415994201cSMaheshRavishankar   if (failed(expansionInfo.compute(
8423f18f6a2SQuinn Dawkins           linalgOp, fusableOpOperand,
8435994201cSMaheshRavishankar           isExpanding ? expandingReshapeOp.getReassociationMaps()
8445994201cSMaheshRavishankar                       : collapsingReshapeOp.getReassociationMaps(),
8454317a3dfSMaheshRavishankar           expandedType.getShape(), collapsedType.getShape(), rewriter)))
8461a36588eSKazu Hirata     return std::nullopt;
8475994201cSMaheshRavishankar 
84897069a86SGaurav Shukla   // TODO: Once the tensor.expand_shape op supports expanding a dimension into
84997069a86SGaurav Shukla   // multiple dynamic dimensions, this case can be handled.
85097069a86SGaurav Shukla   if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
85197069a86SGaurav Shukla     return std::nullopt;
85297069a86SGaurav Shukla 
8533f18f6a2SQuinn Dawkins   if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
8541a36588eSKazu Hirata     return std::nullopt;
8555994201cSMaheshRavishankar 
8565994201cSMaheshRavishankar   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
8573f18f6a2SQuinn Dawkins       llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
8585994201cSMaheshRavishankar         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
8595994201cSMaheshRavishankar       }));
8605994201cSMaheshRavishankar 
861a7bfdc23SMahesh Ravishankar   // Set insertion point to the generic op.
862a7bfdc23SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
8633f18f6a2SQuinn Dawkins   rewriter.setInsertionPoint(linalgOp);
864a7bfdc23SMahesh Ravishankar 
8655994201cSMaheshRavishankar   SmallVector<Value> expandedOpOperands;
8663f18f6a2SQuinn Dawkins   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
8673f18f6a2SQuinn Dawkins   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
8685994201cSMaheshRavishankar     if (opOperand == fusableOpOperand) {
86904235d07SJacques Pienaar       expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
87004235d07SJacques Pienaar                                                : collapsingReshapeOp.getSrc());
8715994201cSMaheshRavishankar       continue;
8725994201cSMaheshRavishankar     }
873a7cccb9cSAlexander Belyaev     if (auto opOperandType =
8745550c821STres Popp             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
8753f18f6a2SQuinn Dawkins       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
8765994201cSMaheshRavishankar       RankedTensorType expandedOperandType =
877ff5de8a9SBenjamin Kramer           getExpandedType(opOperandType, indexingMap, expansionInfo);
8785994201cSMaheshRavishankar       if (expandedOperandType != opOperand->get().getType()) {
8795994201cSMaheshRavishankar         // Reshape the operand to get the right type.
8805994201cSMaheshRavishankar         SmallVector<ReassociationIndices> reassociation =
8815994201cSMaheshRavishankar             getReassociationForExpansion(indexingMap, expansionInfo);
882ff5de8a9SBenjamin Kramer         if (failed(reshapeLikeShapesAreCompatible(
883ff5de8a9SBenjamin Kramer                 [&](const Twine &msg) {
8843f18f6a2SQuinn Dawkins                   return rewriter.notifyMatchFailure(linalgOp, msg);
885ff5de8a9SBenjamin Kramer                 },
886ff5de8a9SBenjamin Kramer                 opOperandType.getShape(), expandedOperandType.getShape(),
887ff5de8a9SBenjamin Kramer                 reassociation,
888ff5de8a9SBenjamin Kramer                 /*isExpandingReshape=*/true)))
8891a36588eSKazu Hirata           return std::nullopt;
890b618880eSAlexander Belyaev         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
89197069a86SGaurav Shukla             loc, expandedOperandType, opOperand->get(), reassociation));
8925994201cSMaheshRavishankar         continue;
8935994201cSMaheshRavishankar       }
8945994201cSMaheshRavishankar     }
8955994201cSMaheshRavishankar     expandedOpOperands.push_back(opOperand->get());
8965994201cSMaheshRavishankar   }
8975994201cSMaheshRavishankar 
8985994201cSMaheshRavishankar   SmallVector<Value> outputs;
8993f18f6a2SQuinn Dawkins   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
9003f18f6a2SQuinn Dawkins     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
9010b2197b0SMatthias Springer     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
9025994201cSMaheshRavishankar     RankedTensorType expandedOutputType =
903ff5de8a9SBenjamin Kramer         getExpandedType(opOperandType, indexingMap, expansionInfo);
9040b2197b0SMatthias Springer     if (expandedOutputType != opOperand.get().getType()) {
9055994201cSMaheshRavishankar       SmallVector<ReassociationIndices> reassociation =
9065994201cSMaheshRavishankar           getReassociationForExpansion(indexingMap, expansionInfo);
907ff5de8a9SBenjamin Kramer       if (failed(reshapeLikeShapesAreCompatible(
908ff5de8a9SBenjamin Kramer               [&](const Twine &msg) {
9093f18f6a2SQuinn Dawkins                 return rewriter.notifyMatchFailure(linalgOp, msg);
910ff5de8a9SBenjamin Kramer               },
911ff5de8a9SBenjamin Kramer               opOperandType.getShape(), expandedOutputType.getShape(),
912ff5de8a9SBenjamin Kramer               reassociation,
913ff5de8a9SBenjamin Kramer               /*isExpandingReshape=*/true)))
9141a36588eSKazu Hirata         return std::nullopt;
915b618880eSAlexander Belyaev       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
91697069a86SGaurav Shukla           loc, expandedOutputType, opOperand.get(), reassociation));
91771604f4cSMahesh Ravishankar     } else {
9180b2197b0SMatthias Springer       outputs.push_back(opOperand.get());
9195994201cSMaheshRavishankar     }
9205994201cSMaheshRavishankar   }
9215994201cSMaheshRavishankar 
9225994201cSMaheshRavishankar   // Set the iterator types of the expanded op: every expanded dimension
9225994201cSMaheshRavishankar   // inherits the iterator type of the original dimension it was expanded from.
923e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType> iteratorTypes(
924e6598b05SOleg Shyshkov       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
9253f18f6a2SQuinn Dawkins   for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
9263f18f6a2SQuinn Dawkins     for (auto j : expansionInfo.getExpandedDims(i))
9273f18f6a2SQuinn Dawkins       iteratorTypes[j] = type;
9285994201cSMaheshRavishankar 
9295994201cSMaheshRavishankar   TypeRange resultTypes = ValueRange(outputs).getTypes();
9305994201cSMaheshRavishankar   auto fusedOp =
9313f18f6a2SQuinn Dawkins       rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
9325994201cSMaheshRavishankar                                  /*inputs=*/expandedOpOperands, outputs,
9335994201cSMaheshRavishankar                                  expandedOpIndexingMaps, iteratorTypes);
9345994201cSMaheshRavishankar   Region &fusedRegion = fusedOp->getRegion(0);
9353f18f6a2SQuinn Dawkins   Region &originalRegion = linalgOp->getRegion(0);
9365994201cSMaheshRavishankar   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
9375994201cSMaheshRavishankar 
9385994201cSMaheshRavishankar   // Update the index accesses after the expansion.
9395994201cSMaheshRavishankar   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
9405994201cSMaheshRavishankar 
9415994201cSMaheshRavishankar   // Reshape the result values to their original shape if this is a collapsing
9425994201cSMaheshRavishankar   // reshape folded into its consumer.
9435994201cSMaheshRavishankar   SmallVector<Value> resultVals;
9443f18f6a2SQuinn Dawkins   for (OpResult opResult : linalgOp->getOpResults()) {
9455994201cSMaheshRavishankar     int64_t resultNumber = opResult.getResultNumber();
946a7bfdc23SMahesh Ravishankar     if (resultTypes[resultNumber] != opResult.getType()) {
9475994201cSMaheshRavishankar       SmallVector<ReassociationIndices> reassociation =
9485994201cSMaheshRavishankar           getReassociationForExpansion(
9493f18f6a2SQuinn Dawkins               linalgOp.getMatchingIndexingMap(
9503f18f6a2SQuinn Dawkins                   linalgOp.getDpsInitOperand(resultNumber)),
9515994201cSMaheshRavishankar               expansionInfo);
952b618880eSAlexander Belyaev       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
9533f18f6a2SQuinn Dawkins           linalgOp.getLoc(), opResult.getType(),
9545994201cSMaheshRavishankar           fusedOp->getResult(resultNumber), reassociation));
9555994201cSMaheshRavishankar     } else {
9565994201cSMaheshRavishankar       resultVals.push_back(fusedOp->getResult(resultNumber));
9575994201cSMaheshRavishankar     }
9585994201cSMaheshRavishankar   }
9605994201cSMaheshRavishankar   return resultVals;
9615994201cSMaheshRavishankar }
9625994201cSMaheshRavishankar 
9635994201cSMaheshRavishankar namespace {
9645994201cSMaheshRavishankar 
9653f18f6a2SQuinn Dawkins /// Pattern to fuse a tensor.collapse_shape op with its consumer structured op
96632288d37SMahesh Ravishankar /// by expanding the dimensionality of the loops in the consumer.
96832288d37SMahesh Ravishankar class FoldWithProducerReshapeOpByExpansion
9693f18f6a2SQuinn Dawkins     : public OpInterfaceRewritePattern<LinalgOp> {
97032288d37SMahesh Ravishankar public:
9712291705dSMahesh Ravishankar   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
9722291705dSMahesh Ravishankar                                        ControlFusionFn foldReshapes,
97332288d37SMahesh Ravishankar                                        PatternBenefit benefit = 1)
9743f18f6a2SQuinn Dawkins       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
97532288d37SMahesh Ravishankar         controlFoldingReshapes(std::move(foldReshapes)) {}
9765994201cSMaheshRavishankar 
9773f18f6a2SQuinn Dawkins   LogicalResult matchAndRewrite(LinalgOp linalgOp,
9785994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
9793f18f6a2SQuinn Dawkins     for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
98032288d37SMahesh Ravishankar       tensor::CollapseShapeOp reshapeOp =
98132288d37SMahesh Ravishankar           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
9825994201cSMaheshRavishankar       if (!reshapeOp)
9835994201cSMaheshRavishankar         continue;
98432288d37SMahesh Ravishankar       // Fold only if
98532288d37SMahesh Ravishankar       // - All constraints of fusing with reshape by expansion are met.
98632288d37SMahesh Ravishankar       // - The control function allows folding this reshape.
9873f18f6a2SQuinn Dawkins       if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
988a7bfdc23SMahesh Ravishankar           (!controlFoldingReshapes(opOperand)))
9895994201cSMaheshRavishankar         continue;
9905994201cSMaheshRavishankar 
9910a81ace0SKazu Hirata       std::optional<SmallVector<Value>> replacementValues =
9923f18f6a2SQuinn Dawkins           fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
99332288d37SMahesh Ravishankar       if (!replacementValues)
9945994201cSMaheshRavishankar         return failure();
9953f18f6a2SQuinn Dawkins       rewriter.replaceOp(linalgOp, *replacementValues);
9965994201cSMaheshRavishankar       return success();
9975994201cSMaheshRavishankar     }
9985994201cSMaheshRavishankar     return failure();
9995994201cSMaheshRavishankar   }
100032288d37SMahesh Ravishankar 
100132288d37SMahesh Ravishankar private:
10022291705dSMahesh Ravishankar   ControlFusionFn controlFoldingReshapes;
10035994201cSMaheshRavishankar };
10045994201cSMaheshRavishankar 
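/// Pattern to fold a tensor.collapse_shape op producing the source of a
/// tensor.pad op by moving the padding to the expanded source, provided only
/// dimensions whose reassociation group has a single dimension are padded.
///
/// An illustrative sketch:
///
/// ```mlir
/// %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
///     : tensor<8x4x?xf32> into tensor<32x?xf32>
/// %1 = tensor.pad %0 low[0, 1] high[0, 3] {.. } : tensor<32x?xf32> to tensor<32x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = tensor.pad %arg0 low[0, 0, 1] high[0, 0, 3] {.. }
///     : tensor<8x4x?xf32> to tensor<8x4x?xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1], [2]]
///     : tensor<8x4x?xf32> into tensor<32x?xf32>
/// ```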
1005c886d66dSMax191 class FoldPadWithProducerReshapeOpByExpansion
1006c886d66dSMax191     : public OpRewritePattern<tensor::PadOp> {
1007c886d66dSMax191 public:
1008c886d66dSMax191   FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1009c886d66dSMax191                                           ControlFusionFn foldReshapes,
1010c886d66dSMax191                                           PatternBenefit benefit = 1)
1011c886d66dSMax191       : OpRewritePattern<tensor::PadOp>(context, benefit),
1012c886d66dSMax191         controlFoldingReshapes(std::move(foldReshapes)) {}
1013c886d66dSMax191 
1014c886d66dSMax191   LogicalResult matchAndRewrite(tensor::PadOp padOp,
1015c886d66dSMax191                                 PatternRewriter &rewriter) const override {
1016c886d66dSMax191     tensor::CollapseShapeOp reshapeOp =
1017c886d66dSMax191         padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1018c886d66dSMax191     if (!reshapeOp)
1019c886d66dSMax191       return failure();
1020c886d66dSMax191     if (!reshapeOp->hasOneUse())
1021c886d66dSMax191       return failure();
1022c886d66dSMax191 
1023c886d66dSMax191     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1024c886d66dSMax191       return rewriter.notifyMatchFailure(padOp,
1025c886d66dSMax191                                          "fusion blocked by control function");
1026c886d66dSMax191     }
1027c886d66dSMax191 
1028c886d66dSMax191     ArrayRef<int64_t> low = padOp.getStaticLow();
1029c886d66dSMax191     ArrayRef<int64_t> high = padOp.getStaticHigh();
1030c886d66dSMax191     SmallVector<ReassociationIndices> reassociations =
1031c886d66dSMax191         reshapeOp.getReassociationIndices();
1032c886d66dSMax191 
1033c886d66dSMax191     for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
1034c886d66dSMax191       if (reInd.size() != 1 && (l != 0 || h != 0))
1035c886d66dSMax191         return failure();
1036c886d66dSMax191     }
1037c886d66dSMax191 
1038c886d66dSMax191     SmallVector<OpFoldResult> newLow, newHigh;
1039c886d66dSMax191     RankedTensorType expandedType = reshapeOp.getSrcType();
1040c886d66dSMax191     RankedTensorType paddedType = padOp.getResultType();
1041c886d66dSMax191     SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
1042c886d66dSMax191     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1043c886d66dSMax191       if (reInd.size() == 1) {
1044c886d66dSMax191         expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
1045c886d66dSMax191       }
1046c886d66dSMax191       for (size_t i = 0; i < reInd.size(); ++i) {
1047c886d66dSMax191         newLow.push_back(padOp.getMixedLowPad()[idx]);
1048c886d66dSMax191         newHigh.push_back(padOp.getMixedHighPad()[idx]);
1049c886d66dSMax191       }
1050c886d66dSMax191     }
1051c886d66dSMax191 
1052c886d66dSMax191     Location loc = padOp->getLoc();
1053c886d66dSMax191     RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
1054c886d66dSMax191     auto newPadOp = rewriter.create<tensor::PadOp>(
1055c886d66dSMax191         loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1056c886d66dSMax191         padOp.getConstantPaddingValue(), padOp.getNofold());
1057c886d66dSMax191 
1058c886d66dSMax191     rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1059c886d66dSMax191         padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
1060c886d66dSMax191 
1061c886d66dSMax191     return success();
1062c886d66dSMax191   }
1063c886d66dSMax191 
1064c886d66dSMax191 private:
1065c886d66dSMax191   ControlFusionFn controlFoldingReshapes;
1066c886d66dSMax191 };
1067c886d66dSMax191 
10681ad9b266Slorenzo chelini /// Pattern to fold a tensor.expand_shape op with its producer generic op
106932288d37SMahesh Ravishankar /// by expanding the dimensionality of the loop in the producer op.
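///
/// An illustrative sketch:
///
/// ```mlir
/// %0 = linalg.generic {..} ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {.. }
/// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
/// %1 = tensor.expand_shape %init [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
/// %2 = linalg.generic {..} ins(%0 : tensor<?x4xf32>) outs(%1 : tensor<?x4xf32>) {.. }
/// ```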
107032288d37SMahesh Ravishankar struct FoldReshapeWithGenericOpByExpansion
107132288d37SMahesh Ravishankar     : public OpRewritePattern<tensor::ExpandShapeOp> {
107232288d37SMahesh Ravishankar 
10732291705dSMahesh Ravishankar   FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
10742291705dSMahesh Ravishankar                                       ControlFusionFn foldReshapes,
107532288d37SMahesh Ravishankar                                       PatternBenefit benefit = 1)
107632288d37SMahesh Ravishankar       : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
107732288d37SMahesh Ravishankar         controlFoldingReshapes(std::move(foldReshapes)) {}
107832288d37SMahesh Ravishankar 
107932288d37SMahesh Ravishankar   LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
108032288d37SMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
108132288d37SMahesh Ravishankar     // Fold only if all constraints of fusing with reshape by expansion are met.
10825550c821STres Popp     auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1083a7bfdc23SMahesh Ravishankar     if (!producerResult) {
1084a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1085a7bfdc23SMahesh Ravishankar                                          "source not produced by an operation");
1086a7bfdc23SMahesh Ravishankar     }
1087a7bfdc23SMahesh Ravishankar 
10883f18f6a2SQuinn Dawkins     auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1089a7bfdc23SMahesh Ravishankar     if (!producer) {
1090a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1091a7bfdc23SMahesh Ravishankar                                          "producer not a generic op");
1092a7bfdc23SMahesh Ravishankar     }
1093a7bfdc23SMahesh Ravishankar 
1094a7bfdc23SMahesh Ravishankar     if (!isFusableWithReshapeByDimExpansion(
1095a7bfdc23SMahesh Ravishankar             producer,
1096b4db15a9SAlexander Belyaev             producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1097a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(
1098a7bfdc23SMahesh Ravishankar           reshapeOp, "failed preconditions of fusion with producer generic op");
1099a7bfdc23SMahesh Ravishankar     }
1100a7bfdc23SMahesh Ravishankar 
11018823e961SMatthias Springer     if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1102a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1103a7bfdc23SMahesh Ravishankar                                          "fusion blocked by control function");
1104a7bfdc23SMahesh Ravishankar     }
1105a7bfdc23SMahesh Ravishankar 
11060a81ace0SKazu Hirata     std::optional<SmallVector<Value>> replacementValues =
11070a81ace0SKazu Hirata         fuseWithReshapeByExpansion(
1108a7bfdc23SMahesh Ravishankar             producer, reshapeOp,
11090a81ace0SKazu Hirata             producer.getDpsInitOperand(producerResult.getResultNumber()),
11100a81ace0SKazu Hirata             rewriter);
1111a7bfdc23SMahesh Ravishankar     if (!replacementValues) {
1112a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1113a7bfdc23SMahesh Ravishankar                                          "fusion by expansion failed");
1114a7bfdc23SMahesh Ravishankar     }
1115a7bfdc23SMahesh Ravishankar 
1116a7bfdc23SMahesh Ravishankar     // Find the replacement for the reshape op. Since the replacements have the
1117a7bfdc23SMahesh Ravishankar     // same type as the returns of the original generic op, the consumer reshape
1118a7bfdc23SMahesh Ravishankar     // op can be replaced by the source of the collapse_shape op that defines
1119a7bfdc23SMahesh Ravishankar     // the replacement.
11205550c821STres Popp     Value reshapeReplacement =
11215550c821STres Popp         (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
11225550c821STres Popp                                  .getResultNumber()];
1123a7bfdc23SMahesh Ravishankar     if (auto collapseOp =
1124a7bfdc23SMahesh Ravishankar             reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
1125a7bfdc23SMahesh Ravishankar       reshapeReplacement = collapseOp.getSrc();
1126a7bfdc23SMahesh Ravishankar     }
1127a7bfdc23SMahesh Ravishankar     rewriter.replaceOp(reshapeOp, reshapeReplacement);
1128a7bfdc23SMahesh Ravishankar     rewriter.replaceOp(producer, *replacementValues);
112932288d37SMahesh Ravishankar     return success();
113032288d37SMahesh Ravishankar   }
113132288d37SMahesh Ravishankar 
113232288d37SMahesh Ravishankar private:
11332291705dSMahesh Ravishankar   ControlFusionFn controlFoldingReshapes;
113432288d37SMahesh Ravishankar };
113532288d37SMahesh Ravishankar } // namespace
113632288d37SMahesh Ravishankar 
113732288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
11382c58cde0SMahesh Ravishankar // Methods and patterns to fuse reshape with linalg.generic operations by
11392c58cde0SMahesh Ravishankar // contraction of dimensions.
11402c58cde0SMahesh Ravishankar //===---------------------------------------------------------------------===//
11412c58cde0SMahesh Ravishankar 
1142b40e9013SMahesh Ravishankar /// For a given list of indices in the range of the `indexingMap` that are
1143192d9dd7SKazu Hirata /// folded, return the indices of the corresponding domain. The projected
1144192d9dd7SKazu Hirata /// permutation semantics ensure that all the elements of the returned
1145192d9dd7SKazu Hirata /// reassociation are distinct.
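/// For example (a sketch), with `indexingMap` `(d0, d1, d2) -> (d2, d0, d1)`
/// and range reassociation `[1, 2]` (i.e. the results `d0` and `d1`), the
/// returned domain reassociation is `[0, 1]`.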
1146b40e9013SMahesh Ravishankar static ReassociationIndices
11472c58cde0SMahesh Ravishankar getDomainReassociation(AffineMap indexingMap,
1148b40e9013SMahesh Ravishankar                        ReassociationIndicesRef rangeReassociation) {
11492c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
1150b40e9013SMahesh Ravishankar          "expected projected permutation");
11512c58cde0SMahesh Ravishankar 
1152b40e9013SMahesh Ravishankar   ReassociationIndices domainReassociation = llvm::to_vector<4>(
1153b40e9013SMahesh Ravishankar       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
11541609f1c2Slong.chen         return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
1155b40e9013SMahesh Ravishankar       }));
1156b40e9013SMahesh Ravishankar   // The projected permutation semantics ensures that there is no repetition of
1157b40e9013SMahesh Ravishankar   // the domain indices.
11582c58cde0SMahesh Ravishankar   return domainReassociation;
11592c58cde0SMahesh Ravishankar }
11602c58cde0SMahesh Ravishankar 
11612c58cde0SMahesh Ravishankar /// For a given `dimSequence`, check if the sequence is conserved in the
11622c58cde0SMahesh Ravishankar /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
11632c58cde0SMahesh Ravishankar /// Non-existence of the sequence returns true as well.
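/// For example (a sketch), the sequence `[1, 2]` is preserved by
/// `(d0, d1, d2) -> (d1, d2)` and by `(d0, d1, d2) -> (d0)` (the sequence does
/// not appear at all), but not by `(d0, d1, d2) -> (d2, d1)` or
/// `(d0, d1, d2) -> (d1)`, where the sequence is reordered or only partially
/// present.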
1164f12639d0SMahesh Ravishankar bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
11652c58cde0SMahesh Ravishankar                                           ReassociationIndicesRef dimSequence) {
11662c58cde0SMahesh Ravishankar   assert(!dimSequence.empty() &&
11672c58cde0SMahesh Ravishankar          "expected non-empty list for dimension sequence");
11682c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
11692c58cde0SMahesh Ravishankar          "expected indexing map to be projected permutation");
11702c58cde0SMahesh Ravishankar 
11712c58cde0SMahesh Ravishankar   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
11722c58cde0SMahesh Ravishankar   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
11732c58cde0SMahesh Ravishankar 
11742c58cde0SMahesh Ravishankar   unsigned dimSequenceStart = dimSequence[0];
1175a91ade0bSAdrian Kuegel   for (const auto &expr : enumerate(indexingMap.getResults())) {
11761609f1c2Slong.chen     unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
11772c58cde0SMahesh Ravishankar     // 1. Check if this is the start of the sequence.
11782c58cde0SMahesh Ravishankar     if (dimInMapStart == dimSequenceStart) {
11792c58cde0SMahesh Ravishankar       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
11802c58cde0SMahesh Ravishankar         return false;
11812c58cde0SMahesh Ravishankar       // 1a. Check if sequence is preserved.
1182a91ade0bSAdrian Kuegel       for (const auto &dimInSequence : enumerate(dimSequence)) {
11832c58cde0SMahesh Ravishankar         unsigned dimInMap =
11841609f1c2Slong.chen             cast<AffineDimExpr>(
11851609f1c2Slong.chen                 indexingMap.getResult(expr.index() + dimInSequence.index()))
11862c58cde0SMahesh Ravishankar                 .getPosition();
11872c58cde0SMahesh Ravishankar         if (dimInMap != dimInSequence.value())
11882c58cde0SMahesh Ravishankar           return false;
11892c58cde0SMahesh Ravishankar       }
11902c58cde0SMahesh Ravishankar       // Found the sequence. Projected permutation
11912c58cde0SMahesh Ravishankar       // enforces that all AffineDimExprs in the result are unique, so no
11922c58cde0SMahesh Ravishankar       // further checks are needed.
11932c58cde0SMahesh Ravishankar       return true;
11942c58cde0SMahesh Ravishankar     }
11952c58cde0SMahesh Ravishankar     // 2. If the position of the expr (which is of type AffineDimExpr) is part
11962c58cde0SMahesh Ravishankar     // of the sequence, return false here. This implies the entire sequence does
11972c58cde0SMahesh Ravishankar     // not exist in the indexing map.
11982c58cde0SMahesh Ravishankar     if (sequenceElements.count(dimInMapStart))
11992c58cde0SMahesh Ravishankar       return false;
12002c58cde0SMahesh Ravishankar   }
12012c58cde0SMahesh Ravishankar   // 3. No element of sequence found. Return true.
12022c58cde0SMahesh Ravishankar   return true;
12032c58cde0SMahesh Ravishankar }
12042c58cde0SMahesh Ravishankar 
1205f12639d0SMahesh Ravishankar bool mlir::linalg::areDimSequencesPreserved(
1206f12639d0SMahesh Ravishankar     ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
1207f12639d0SMahesh Ravishankar   return llvm::all_of(maps, [&](AffineMap map) {
1208f12639d0SMahesh Ravishankar     return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1209f12639d0SMahesh Ravishankar       return isDimSequencePreserved(map, dimSequence);
1210f12639d0SMahesh Ravishankar     });
1211f12639d0SMahesh Ravishankar   });
1212f12639d0SMahesh Ravishankar }
1213f12639d0SMahesh Ravishankar 
1214b40e9013SMahesh Ravishankar // Return the list of dimensions of the iteration domain that can be
1215b40e9013SMahesh Ravishankar // collapsed to allow for fusion with a producer that is an expand_shape
1216b40e9013SMahesh Ravishankar // operation. If all dimensions created by expansion can be collapsed in the
1217b40e9013SMahesh Ravishankar // iteration space then the reshape is defunct.
1218b40e9013SMahesh Ravishankar //
1219b40e9013SMahesh Ravishankar // Example:
1220b40e9013SMahesh Ravishankar //
1221b40e9013SMahesh Ravishankar // ```mlir
1222b40e9013SMahesh Ravishankar // #map = affine_map<(d0, d1) -> (d0, d1)>
1223b40e9013SMahesh Ravishankar // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
122481ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<?x4xf32>
1225b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1226b40e9013SMahesh Ravishankar //     indexing_maps = [#map, #map],
1227b40e9013SMahesh Ravishankar //     iterator_types = ["parallel" ,"parallel"]}
1228b40e9013SMahesh Ravishankar //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1229b40e9013SMahesh Ravishankar // ```
1230b40e9013SMahesh Ravishankar //
1231b40e9013SMahesh Ravishankar // can be fused by collapsing the dimensions of the iteration space.
1232b40e9013SMahesh Ravishankar //
1233b40e9013SMahesh Ravishankar // ```mlir
1234b40e9013SMahesh Ravishankar // #map = affine_map<(d0) -> (d0)>
123581ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<?xf32>
1236b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1237b40e9013SMahesh Ravishankar //     indexing_maps = [#map, #map],
1238b40e9013SMahesh Ravishankar //     iterator_types = ["parallel"]}
1239b40e9013SMahesh Ravishankar //     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1240b40e9013SMahesh Ravishankar // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1241b40e9013SMahesh Ravishankar // ```
1242b40e9013SMahesh Ravishankar //
1243b40e9013SMahesh Ravishankar // In the following example,
1244b40e9013SMahesh Ravishankar //
1245b40e9013SMahesh Ravishankar // ```mlir
1246b40e9013SMahesh Ravishankar // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1247b40e9013SMahesh Ravishankar // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1248b40e9013SMahesh Ravishankar // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
124981ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<4x?xf32>
1250b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1251b40e9013SMahesh Ravishankar //     indexing_maps = [#map0, #map1],
1252b40e9013SMahesh Ravishankar //     iterator_types = ["parallel" ,"parallel"]}
1253b40e9013SMahesh Ravishankar //     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1254b40e9013SMahesh Ravishankar // ```
1255b40e9013SMahesh Ravishankar //
1256b40e9013SMahesh Ravishankar // the reshape cannot be fused with the generic op by collapsing the op
1257b40e9013SMahesh Ravishankar // dimensions since the indexing maps will have to contain mods and divs
1258b40e9013SMahesh Ravishankar // to preserve the access pattern. When no dimensions of the iteration
1259b40e9013SMahesh Ravishankar // space are collapsible, an empty vector is returned.
1260b40e9013SMahesh Ravishankar static SmallVector<ReassociationIndices>
1261b40e9013SMahesh Ravishankar getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
12622c58cde0SMahesh Ravishankar                                  ArrayRef<ReassociationIndices> reassociation) {
12632c58cde0SMahesh Ravishankar   // Some basic checks for this fusion to be valid.
126408efa230SMax191   if (!genericOp.hasPureTensorSemantics())
1265b40e9013SMahesh Ravishankar     return {};
12662c58cde0SMahesh Ravishankar 
1267d2c0572bSJacques Pienaar   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
12682c58cde0SMahesh Ravishankar         return map.isProjectedPermutation();
12692c58cde0SMahesh Ravishankar       })) {
1270b40e9013SMahesh Ravishankar     return {};
12712c58cde0SMahesh Ravishankar   }
12722c58cde0SMahesh Ravishankar 
1273b40e9013SMahesh Ravishankar   // Compute all the loops with the reduction iterator types.
1274c54bc8bdSOleg Shyshkov   SmallVector<unsigned> reductionDims;
1275c54bc8bdSOleg Shyshkov   genericOp.getReductionDims(reductionDims);
12762c58cde0SMahesh Ravishankar 
1277b40e9013SMahesh Ravishankar   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
12781227b8abSOleg Shyshkov   AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1279c54bc8bdSOleg Shyshkov   auto iteratorTypes = genericOp.getIteratorTypesArray();
1280b40e9013SMahesh Ravishankar   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1281b40e9013SMahesh Ravishankar   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1282b40e9013SMahesh Ravishankar     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1283b40e9013SMahesh Ravishankar 
1284b40e9013SMahesh Ravishankar     // Ignore dims that are not folded.
1285b40e9013SMahesh Ravishankar     if (foldedRangeDims.size() == 1)
1286b40e9013SMahesh Ravishankar       continue;
1287b40e9013SMahesh Ravishankar 
1288b40e9013SMahesh Ravishankar     ReassociationIndices foldedIterationSpaceDims =
1289b40e9013SMahesh Ravishankar         getDomainReassociation(indexingMap, foldedRangeDims);
1290b40e9013SMahesh Ravishankar 
1291b40e9013SMahesh Ravishankar     // Check that the folded iteration dims do not contain already processed
1292b40e9013SMahesh Ravishankar     // dims.
1293b40e9013SMahesh Ravishankar     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1294b40e9013SMahesh Ravishankar           return processedIterationDims.count(dim);
12952c58cde0SMahesh Ravishankar         }))
1296b40e9013SMahesh Ravishankar       continue;
1297b40e9013SMahesh Ravishankar 
1298b40e9013SMahesh Ravishankar     // Check that the folded iterator types are either all parallel or all reduction.
1299e6598b05SOleg Shyshkov     utils::IteratorType startIteratorType =
1300e6598b05SOleg Shyshkov         iteratorTypes[foldedIterationSpaceDims[0]];
1301b40e9013SMahesh Ravishankar     if (!isParallelIterator(startIteratorType) &&
1302b40e9013SMahesh Ravishankar         !isReductionIterator(startIteratorType))
1303b40e9013SMahesh Ravishankar       continue;
1304b40e9013SMahesh Ravishankar     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1305b40e9013SMahesh Ravishankar           return iteratorTypes[dim] != startIteratorType;
1306b40e9013SMahesh Ravishankar         }))
1307b40e9013SMahesh Ravishankar       continue;
1308b40e9013SMahesh Ravishankar 
1309b40e9013SMahesh Ravishankar     // If the folded dimensions correspond to a "reduction" iterator type,
1310b40e9013SMahesh Ravishankar     // the folded dimensions need to be "in-order". Strictly speaking this is
1311b40e9013SMahesh Ravishankar     // not necessary for reductions that are associative and commutative, but
1312b40e9013SMahesh Ravishankar     // a stricter definition of reduction is used for now.
1313b40e9013SMahesh Ravishankar     if (isReductionIterator(startIteratorType)) {
1314b40e9013SMahesh Ravishankar       bool isContiguous = false;
13156120bd47SMehdi Amini       for (const auto &startDim : llvm::enumerate(reductionDims)) {
1316b40e9013SMahesh Ravishankar         // Move window in `reductionDims` to start of the folded iteration dims.
1317b40e9013SMahesh Ravishankar         if (startDim.value() != foldedIterationSpaceDims[0])
1318b40e9013SMahesh Ravishankar           continue;
1319b40e9013SMahesh Ravishankar         // If the sizes don't match, it is trivially not contiguous. This
1320b40e9013SMahesh Ravishankar         // condition should not be hit.
1321b40e9013SMahesh Ravishankar         if (startDim.index() + foldedIterationSpaceDims.size() >
1322b40e9013SMahesh Ravishankar             reductionDims.size())
1323b40e9013SMahesh Ravishankar           break;
1324b40e9013SMahesh Ravishankar         // Check that the contiguity is maintained.
1325b40e9013SMahesh Ravishankar         isContiguous = true;
13266120bd47SMehdi Amini         for (const auto &foldedDim :
13276120bd47SMehdi Amini              llvm::enumerate(foldedIterationSpaceDims)) {
1328b40e9013SMahesh Ravishankar           if (reductionDims[foldedDim.index() + startDim.index()] !=
1329b40e9013SMahesh Ravishankar               foldedDim.value()) {
1330b40e9013SMahesh Ravishankar             isContiguous = false;
1331b40e9013SMahesh Ravishankar             break;
13322c58cde0SMahesh Ravishankar           }
1333b40e9013SMahesh Ravishankar         }
1334b40e9013SMahesh Ravishankar         break;
1335b40e9013SMahesh Ravishankar       }
1336b40e9013SMahesh Ravishankar       if (!isContiguous)
1337b40e9013SMahesh Ravishankar         continue;
1338b40e9013SMahesh Ravishankar     }
1339b40e9013SMahesh Ravishankar 
1340b40e9013SMahesh Ravishankar     // Check that the sequence is preserved in all indexing maps.
1341d2c0572bSJacques Pienaar     if (llvm::any_of(genericOp.getIndexingMapsArray(),
1342d2c0572bSJacques Pienaar                      [&](AffineMap indexingMap) {
1343d2c0572bSJacques Pienaar                        return !isDimSequencePreserved(indexingMap,
1344d2c0572bSJacques Pienaar                                                       foldedIterationSpaceDims);
1345b40e9013SMahesh Ravishankar                      }))
1346b40e9013SMahesh Ravishankar       continue;
1347b40e9013SMahesh Ravishankar 
1348b40e9013SMahesh Ravishankar     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1349b40e9013SMahesh Ravishankar                                   foldedIterationSpaceDims.end());
1350b40e9013SMahesh Ravishankar     iterationSpaceReassociation.emplace_back(
1351b40e9013SMahesh Ravishankar         std::move(foldedIterationSpaceDims));
1352b40e9013SMahesh Ravishankar   }
1353b40e9013SMahesh Ravishankar 
1354b40e9013SMahesh Ravishankar   return iterationSpaceReassociation;
13552c58cde0SMahesh Ravishankar }
13562c58cde0SMahesh Ravishankar 
13572c58cde0SMahesh Ravishankar /// Helper class to carry state while collapsing the `linalg.generic` op.
13582c58cde0SMahesh Ravishankar namespace {
13592c58cde0SMahesh Ravishankar class CollapsingInfo {
13602c58cde0SMahesh Ravishankar public:
1361b40e9013SMahesh Ravishankar   LogicalResult initialize(unsigned origNumLoops,
1362b40e9013SMahesh Ravishankar                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1363b40e9013SMahesh Ravishankar     llvm::SmallDenseSet<int64_t, 4> processedDims;
1364b40e9013SMahesh Ravishankar     // Find all the dims that are folded.
1365b40e9013SMahesh Ravishankar     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1366b40e9013SMahesh Ravishankar       if (foldedIterationDim.empty())
1367b40e9013SMahesh Ravishankar         continue;
1368b40e9013SMahesh Ravishankar       // If the folded dims contain dims already folded, that's an illegal
1369b40e9013SMahesh Ravishankar       // specification. Repetition within a list is also illegal.
1370b40e9013SMahesh Ravishankar       for (auto dim : foldedIterationDim) {
1371b40e9013SMahesh Ravishankar         if (dim >= origNumLoops)
1372b40e9013SMahesh Ravishankar           return failure();
1373b40e9013SMahesh Ravishankar         if (processedDims.count(dim))
1374b40e9013SMahesh Ravishankar           return failure();
1375b40e9013SMahesh Ravishankar         processedDims.insert(dim);
13762c58cde0SMahesh Ravishankar       }
1377b40e9013SMahesh Ravishankar       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1378b40e9013SMahesh Ravishankar                                                    foldedIterationDim.end());
1379b40e9013SMahesh Ravishankar     }
1380b40e9013SMahesh Ravishankar     if (processedDims.size() > origNumLoops)
1381b40e9013SMahesh Ravishankar       return failure();
1382b40e9013SMahesh Ravishankar 
1383b40e9013SMahesh Ravishankar     // Add all the preserved dims of the original op as single
1384b40e9013SMahesh Ravishankar     // elements to `collapsedOpToOrigOpIterationDim`.
1385b40e9013SMahesh Ravishankar     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1386b40e9013SMahesh Ravishankar       if (processedDims.count(dim))
1387b40e9013SMahesh Ravishankar         continue;
1388b40e9013SMahesh Ravishankar       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
13892c58cde0SMahesh Ravishankar     }
13902c58cde0SMahesh Ravishankar 
1391b40e9013SMahesh Ravishankar     llvm::sort(collapsedOpToOrigOpIterationDim,
1392b40e9013SMahesh Ravishankar                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1393b40e9013SMahesh Ravishankar                  return lhs[0] < rhs[0];
1394b40e9013SMahesh Ravishankar                });
1395b40e9013SMahesh Ravishankar     origOpToCollapsedOpIterationDim.resize(origNumLoops);
13966120bd47SMehdi Amini     for (const auto &foldedDims :
13976120bd47SMehdi Amini          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
13986120bd47SMehdi Amini       for (const auto &dim : enumerate(foldedDims.value()))
1399b40e9013SMahesh Ravishankar         origOpToCollapsedOpIterationDim[dim.value()] =
1400b40e9013SMahesh Ravishankar             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1401b40e9013SMahesh Ravishankar     }
1402b40e9013SMahesh Ravishankar     return success();
14032c58cde0SMahesh Ravishankar   }
14042c58cde0SMahesh Ravishankar 
1405b40e9013SMahesh Ravishankar   /// Return mapping from collapsed loop domain to original loop domain.
1406b40e9013SMahesh Ravishankar   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1407b40e9013SMahesh Ravishankar     return collapsedOpToOrigOpIterationDim;
14082c58cde0SMahesh Ravishankar   }
14092c58cde0SMahesh Ravishankar 
1410b40e9013SMahesh Ravishankar   /// Return mapping from original loop domain to collapsed loop domain. The
1411b40e9013SMahesh Ravishankar   /// mapping is a pair. The first value is the dimension in the collapsed loop
1412b40e9013SMahesh Ravishankar   /// that the original loop is mapped to. The second is the relative position
1413b40e9013SMahesh Ravishankar   /// in the folded list of this domain. For example, if the original loop
1414b40e9013SMahesh Ravishankar   /// domain is 5D and the collapsed loop domain folds all of it, i.e.
1415b40e9013SMahesh Ravishankar   ///
1416b40e9013SMahesh Ravishankar   /// ```
1417b40e9013SMahesh Ravishankar   /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1418b40e9013SMahesh Ravishankar   /// ```
1419b40e9013SMahesh Ravishankar   ///
1420b40e9013SMahesh Ravishankar   /// then
1421b40e9013SMahesh Ravishankar   ///
1422b40e9013SMahesh Ravishankar   /// ```
1423b40e9013SMahesh Ravishankar   ///  origOpToCollapsedOpMapping[0] = {0, 0};
1424b40e9013SMahesh Ravishankar   ///  origOpToCollapsedOpMapping[1] = {0, 1};
1425b40e9013SMahesh Ravishankar   ///  origOpToCollapsedOpMapping[2] = {0, 2};
1426b40e9013SMahesh Ravishankar   ///  origOpToCollapsedOpMapping[3] = {1, 0};
1427b40e9013SMahesh Ravishankar   ///  origOpToCollapsedOpMapping[4] = {1, 1};
1428b40e9013SMahesh Ravishankar   /// ```
1429b40e9013SMahesh Ravishankar   ///
1430b40e9013SMahesh Ravishankar   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1431b40e9013SMahesh Ravishankar     return origOpToCollapsedOpIterationDim;
14322c58cde0SMahesh Ravishankar   }
14332c58cde0SMahesh Ravishankar 
1434b40e9013SMahesh Ravishankar   /// Return the collapsed op iteration domain rank.
1435b40e9013SMahesh Ravishankar   unsigned getCollapsedOpIterationRank() const {
1436b40e9013SMahesh Ravishankar     return collapsedOpToOrigOpIterationDim.size();
14372c58cde0SMahesh Ravishankar   }
14382c58cde0SMahesh Ravishankar 
14392c58cde0SMahesh Ravishankar private:
1440b40e9013SMahesh Ravishankar   /// Map from the iteration domain index in collapsed op to the iteration
1441b40e9013SMahesh Ravishankar   /// domain indices in the original op.
1442b40e9013SMahesh Ravishankar   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
14432c58cde0SMahesh Ravishankar 
1444b40e9013SMahesh Ravishankar   /// Map from iteration domain index in the original op to the iteration domain
1445b40e9013SMahesh Ravishankar   /// index in the collapsed op.
1446b40e9013SMahesh Ravishankar   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
14472c58cde0SMahesh Ravishankar };
14482c58cde0SMahesh Ravishankar } // namespace
14492c58cde0SMahesh Ravishankar 
14502c58cde0SMahesh Ravishankar /// Get the iterator types for the collapsed operation given the original
14512c58cde0SMahesh Ravishankar /// iterator types and collapsed dimensions.
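/// For example (a sketch), with collapsed dimensions `[[0, 1], [2]]` and
/// original iterator types `["parallel", "parallel", "reduction"]`, the
/// collapsed iterator types are `["parallel", "reduction"]`.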
1452e6598b05SOleg Shyshkov static SmallVector<utils::IteratorType>
1453e6598b05SOleg Shyshkov getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1454b40e9013SMahesh Ravishankar                             const CollapsingInfo &collapsingInfo) {
1455e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType> collapsedIteratorTypes;
14562c58cde0SMahesh Ravishankar   for (ReassociationIndicesRef foldedIterDims :
1457b40e9013SMahesh Ravishankar        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
14582c58cde0SMahesh Ravishankar     assert(!foldedIterDims.empty() &&
14592c58cde0SMahesh Ravishankar            "reassociation indices expected to have non-empty sets");
14602c58cde0SMahesh Ravishankar     // Just pick the iterator type of the first folded dim. Pre-condition checks
14612c58cde0SMahesh Ravishankar     // are expected to have verified that the iterator types of all folded
14622c58cde0SMahesh Ravishankar     // dimensions are the same.
1463e6598b05SOleg Shyshkov     collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
14642c58cde0SMahesh Ravishankar   }
14652c58cde0SMahesh Ravishankar   return collapsedIteratorTypes;
14662c58cde0SMahesh Ravishankar }
14672c58cde0SMahesh Ravishankar 
14682c58cde0SMahesh Ravishankar /// Compute the indexing map in the collapsed op that corresponds to the given
14692c58cde0SMahesh Ravishankar /// `indexingMap` of the original operation.
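/// For example (a sketch), with collapsed dimensions `[[0, 1], [2]]`, the
/// original indexing map `(d0, d1, d2) -> (d0, d1, d2)` becomes
/// `(d0, d1) -> (d0, d1)` in the collapsed op.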
1470b40e9013SMahesh Ravishankar static AffineMap
1471b40e9013SMahesh Ravishankar getCollapsedOpIndexingMap(AffineMap indexingMap,
1472b40e9013SMahesh Ravishankar                           const CollapsingInfo &collapsingInfo) {
14732c58cde0SMahesh Ravishankar   MLIRContext *context = indexingMap.getContext();
14742c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
14752c58cde0SMahesh Ravishankar          "expected indexing map to be projected permutation");
14762c58cde0SMahesh Ravishankar   SmallVector<AffineExpr> resultExprs;
1477b40e9013SMahesh Ravishankar   auto origOpToCollapsedOpMapping =
1478b40e9013SMahesh Ravishankar       collapsingInfo.getOrigOpToCollapsedOpMapping();
14792c58cde0SMahesh Ravishankar   for (auto expr : indexingMap.getResults()) {
14801609f1c2Slong.chen     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1481b40e9013SMahesh Ravishankar     // If the dim is not the first of its group of collapsed dims, do nothing.
1482b40e9013SMahesh Ravishankar     if (origOpToCollapsedOpMapping[dim].second != 0)
1483b40e9013SMahesh Ravishankar       continue;
1484b40e9013SMahesh Ravishankar     // The next n-dims are guaranteed to be collapsed. So just use the
1485b40e9013SMahesh Ravishankar     // iteration dimension of the collapsed op.
1486b40e9013SMahesh Ravishankar     resultExprs.push_back(
1487b40e9013SMahesh Ravishankar         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
14882c58cde0SMahesh Ravishankar   }
1489b40e9013SMahesh Ravishankar   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
14902c58cde0SMahesh Ravishankar                         resultExprs, context);
14912c58cde0SMahesh Ravishankar }
14922c58cde0SMahesh Ravishankar 
14932c58cde0SMahesh Ravishankar /// Return the `reassociation` indices to use to collapse the operand when the
14942c58cde0SMahesh Ravishankar /// iteration space of a generic op is collapsed.
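/// For example (a sketch), with collapsed dimensions `[[0, 1], [2]]`, an
/// operand with indexing map `(d0, d1, d2) -> (d0, d1, d2)` gets the
/// reassociation `[[0, 1], [2]]`, while an operand with indexing map
/// `(d0, d1, d2) -> (d2)` keeps the trivial reassociation `[[0]]`.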
14952c58cde0SMahesh Ravishankar static SmallVector<ReassociationIndices>
1496b40e9013SMahesh Ravishankar getOperandReassociation(AffineMap indexingMap,
1497b40e9013SMahesh Ravishankar                         const CollapsingInfo &collapsingInfo) {
14982c58cde0SMahesh Ravishankar   unsigned counter = 0;
14992c58cde0SMahesh Ravishankar   SmallVector<ReassociationIndices> operandReassociation;
1500b40e9013SMahesh Ravishankar   auto origOpToCollapsedOpMapping =
1501b40e9013SMahesh Ravishankar       collapsingInfo.getOrigOpToCollapsedOpMapping();
1502b40e9013SMahesh Ravishankar   auto collapsedOpToOrigOpMapping =
1503b40e9013SMahesh Ravishankar       collapsingInfo.getCollapsedOpToOrigOpMapping();
1504b40e9013SMahesh Ravishankar   while (counter < indexingMap.getNumResults()) {
1505b40e9013SMahesh Ravishankar     unsigned dim =
15061609f1c2Slong.chen         cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1507b40e9013SMahesh Ravishankar     // This is the start of a group of collapsed dimensions of the iteration
1508b40e9013SMahesh Ravishankar     // space that is guaranteed to be preserved in the indexing map. The number
1509b40e9013SMahesh Ravishankar     // of folded dims is obtained from the collapsed op to original op mapping.
15102c58cde0SMahesh Ravishankar     unsigned numFoldedDims =
1511b40e9013SMahesh Ravishankar         collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1512b40e9013SMahesh Ravishankar             .size();
151312cc8e73SGuray Ozen     if (origOpToCollapsedOpMapping[dim].second == 0) {
15142c58cde0SMahesh Ravishankar       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
15152c58cde0SMahesh Ravishankar       operandReassociation.emplace_back(range.begin(), range.end());
15162c58cde0SMahesh Ravishankar     }
151712cc8e73SGuray Ozen     counter += numFoldedDims;
15182c58cde0SMahesh Ravishankar   }
15192c58cde0SMahesh Ravishankar   return operandReassociation;
15202c58cde0SMahesh Ravishankar }
15212c58cde0SMahesh Ravishankar 
15222c58cde0SMahesh Ravishankar /// Get the new value to use for a given `OpOperand` in the collapsed operation.
15235c3ed392SAviad Cohen static Value getCollapsedOpOperand(Location loc, LinalgOp op,
15242c58cde0SMahesh Ravishankar                                    OpOperand *opOperand,
1525b40e9013SMahesh Ravishankar                                    const CollapsingInfo &collapsingInfo,
15262c58cde0SMahesh Ravishankar                                    OpBuilder &builder) {
15275c3ed392SAviad Cohen   AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
15282c58cde0SMahesh Ravishankar   SmallVector<ReassociationIndices> operandReassociation =
15292c58cde0SMahesh Ravishankar       getOperandReassociation(indexingMap, collapsingInfo);
15302c58cde0SMahesh Ravishankar 
15315c3ed392SAviad Cohen   // If the number of entries in the reassociation for the operand is the same
15325c3ed392SAviad Cohen   // as the number of results of the indexing map, then there is nothing to do
15335c3ed392SAviad Cohen   // for this operand.
15342c58cde0SMahesh Ravishankar   Value operand = opOperand->get();
15352c58cde0SMahesh Ravishankar   if (operandReassociation.size() == indexingMap.getNumResults())
15362c58cde0SMahesh Ravishankar     return operand;
15372c58cde0SMahesh Ravishankar 
15382c58cde0SMahesh Ravishankar   // Insert a reshape to collapse the dimensions.
1539d4ae7ee6SAviad Cohen   if (isa<MemRefType>(operand.getType())) {
1540d4ae7ee6SAviad Cohen     return builder
1541d4ae7ee6SAviad Cohen         .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1542d4ae7ee6SAviad Cohen         .getResult();
154346ce993dSMehdi Amini   }
1544d4ae7ee6SAviad Cohen   return builder
1545d4ae7ee6SAviad Cohen       .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1546d4ae7ee6SAviad Cohen       .getResult();
1547d4ae7ee6SAviad Cohen }
15482c58cde0SMahesh Ravishankar 
15492c58cde0SMahesh Ravishankar /// Modify the `linalg.index` operations in the original generic op to their
15502c58cde0SMahesh Ravishankar /// values in the collapsed operation.
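///
/// As an illustrative sketch, if dimensions `d0` and `d1` are folded and `%r1`
/// is the loop range of `d1`, the original indices are recovered as:
///
/// ```mlir
/// %folded = linalg.index 0 : index
/// %i1 = arith.remsi %folded, %r1 : index
/// %i0 = arith.divsi %folded, %r1 : index
/// ```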
15512c58cde0SMahesh Ravishankar void generateCollapsedIndexingRegion(Location loc, Block *block,
1552b40e9013SMahesh Ravishankar                                      const CollapsingInfo &collapsingInfo,
15532c58cde0SMahesh Ravishankar                                      ValueRange loopRange,
1554135977c9SGuray Ozen                                      RewriterBase &rewriter) {
15552c58cde0SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
15562c58cde0SMahesh Ravishankar   rewriter.setInsertionPointToStart(block);
15572c58cde0SMahesh Ravishankar 
15582c58cde0SMahesh Ravishankar   // Collect all the original index ops.
15592c58cde0SMahesh Ravishankar   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
15602c58cde0SMahesh Ravishankar 
15612c58cde0SMahesh Ravishankar   // For each folded dimension list resolve the original induction variable
15622c58cde0SMahesh Ravishankar   // values in terms of the folded dimension induction variable.
15632c58cde0SMahesh Ravishankar   //   i_{folded} = (i_0 * d1 + i1) * d2 + i2.
15642c58cde0SMahesh Ravishankar   // can be inverted to
15652c58cde0SMahesh Ravishankar   //   i2 = i_{folded} % d2
15662c58cde0SMahesh Ravishankar   //   i1 = (i_{folded} / d2) % d1
15672c58cde0SMahesh Ravishankar   //   i0 = i_{folded} / (d1 * d2)
15682c58cde0SMahesh Ravishankar   llvm::DenseMap<unsigned, Value> indexReplacementVals;
1569a0a76804SJakub Kuderski   for (auto foldedDims :
1570b40e9013SMahesh Ravishankar        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
15712c58cde0SMahesh Ravishankar     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
15722c58cde0SMahesh Ravishankar     Value newIndexVal =
15732c58cde0SMahesh Ravishankar         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
15742c58cde0SMahesh Ravishankar     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
15752c58cde0SMahesh Ravishankar       indexReplacementVals[dim] =
1576*1f5335c1SMaheshRavishankar           rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
15772c58cde0SMahesh Ravishankar       newIndexVal =
1578*1f5335c1SMaheshRavishankar           rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
15792c58cde0SMahesh Ravishankar     }
15802c58cde0SMahesh Ravishankar     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
15812c58cde0SMahesh Ravishankar   }
15822c58cde0SMahesh Ravishankar 
15832c58cde0SMahesh Ravishankar   for (auto indexOp : indexOps) {
1584d3b3f765SJacques Pienaar     auto dim = indexOp.getDim();
15852c58cde0SMahesh Ravishankar     rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
15862c58cde0SMahesh Ravishankar   }
15872c58cde0SMahesh Ravishankar }
15882c58cde0SMahesh Ravishankar 
1589b6f4dd9eSsrcarroll void collapseOperandsAndResults(LinalgOp op,
15905c3ed392SAviad Cohen                                 const CollapsingInfo &collapsingInfo,
1591b6f4dd9eSsrcarroll                                 RewriterBase &rewriter,
1592b6f4dd9eSsrcarroll                                 SmallVectorImpl<Value> &inputOperands,
1593b6f4dd9eSsrcarroll                                 SmallVectorImpl<Value> &outputOperands,
1594b6f4dd9eSsrcarroll                                 SmallVectorImpl<Type> &resultTypes) {
15955c3ed392SAviad Cohen   Location loc = op->getLoc();
1596b6f4dd9eSsrcarroll   inputOperands =
15975c3ed392SAviad Cohen       llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
15985c3ed392SAviad Cohen         return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
15995c3ed392SAviad Cohen                                      rewriter);
16005c3ed392SAviad Cohen       });
16015c3ed392SAviad Cohen 
16025c3ed392SAviad Cohen   // Get the output operands and result types.
16035c3ed392SAviad Cohen   resultTypes.reserve(op.getNumDpsInits());
16045c3ed392SAviad Cohen   outputOperands.reserve(op.getNumDpsInits());
16055c3ed392SAviad Cohen   for (OpOperand &output : op.getDpsInitsMutable()) {
16065c3ed392SAviad Cohen     Value newOutput =
16075c3ed392SAviad Cohen         getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
16085c3ed392SAviad Cohen     outputOperands.push_back(newOutput);
16095c3ed392SAviad Cohen     // If the op has "buffer semantics", then the init operands are ranked
16105c3ed392SAviad Cohen     // memrefs and the op has no results.
16110a8e3dd4SMatthias Springer     if (!op.hasPureBufferSemantics())
16125c3ed392SAviad Cohen       resultTypes.push_back(newOutput.getType());
16135c3ed392SAviad Cohen   }
16145c3ed392SAviad Cohen }
16155c3ed392SAviad Cohen 
1616b6f4dd9eSsrcarroll /// Clone a `LinalgOp` to a collapsed version of the same name.
1617b6f4dd9eSsrcarroll template <typename OpTy>
1618b6f4dd9eSsrcarroll OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1619b6f4dd9eSsrcarroll                         const CollapsingInfo &collapsingInfo) {
1620b6f4dd9eSsrcarroll   return nullptr;
1621b6f4dd9eSsrcarroll }
16225c3ed392SAviad Cohen 
1623b6f4dd9eSsrcarroll /// Collapse any `LinalgOp` that does not require specialized handling of
1624b6f4dd9eSsrcarroll /// attributes such as indexing_maps, iterator_types, etc.
1625b6f4dd9eSsrcarroll template <>
1626b6f4dd9eSsrcarroll LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1627b6f4dd9eSsrcarroll                                       const CollapsingInfo &collapsingInfo) {
1628b6f4dd9eSsrcarroll   SmallVector<Value> inputOperands, outputOperands;
1629b6f4dd9eSsrcarroll   SmallVector<Type> resultTypes;
1630b6f4dd9eSsrcarroll   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1631b6f4dd9eSsrcarroll                              outputOperands, resultTypes);
1632ccc02563SAviad Cohen 
1633ccc02563SAviad Cohen   return clone(
1634b6f4dd9eSsrcarroll       rewriter, origOp, resultTypes,
1635ccc02563SAviad Cohen       llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1636b6f4dd9eSsrcarroll }
1637b6f4dd9eSsrcarroll 
1638b6f4dd9eSsrcarroll /// Collapse a `GenericOp`, recomputing its indexing maps and iterator types.
1639b6f4dd9eSsrcarroll template <>
1640b6f4dd9eSsrcarroll GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1641b6f4dd9eSsrcarroll                                         GenericOp origOp,
1642b6f4dd9eSsrcarroll                                         const CollapsingInfo &collapsingInfo) {
1643b6f4dd9eSsrcarroll   SmallVector<Value> inputOperands, outputOperands;
1644b6f4dd9eSsrcarroll   SmallVector<Type> resultTypes;
1645b6f4dd9eSsrcarroll   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1646b6f4dd9eSsrcarroll                              outputOperands, resultTypes);
1647b6f4dd9eSsrcarroll   SmallVector<AffineMap> indexingMaps(
1648b6f4dd9eSsrcarroll       llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
16495c3ed392SAviad Cohen         return getCollapsedOpIndexingMap(map, collapsingInfo);
1650b6f4dd9eSsrcarroll       }));
16515c3ed392SAviad Cohen 
1652b6f4dd9eSsrcarroll   SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1653b6f4dd9eSsrcarroll       origOp.getIteratorTypesArray(), collapsingInfo));
1654b6f4dd9eSsrcarroll 
1655b6f4dd9eSsrcarroll   GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1656b6f4dd9eSsrcarroll       origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
16575c3ed392SAviad Cohen       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1658b6f4dd9eSsrcarroll   Block *origOpBlock = &origOp->getRegion(0).front();
16595c3ed392SAviad Cohen   Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
16605c3ed392SAviad Cohen   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
16615c3ed392SAviad Cohen                        collapsedOpBlock->getArguments());
16625c3ed392SAviad Cohen   return collapsedOp;
16635c3ed392SAviad Cohen }
16645c3ed392SAviad Cohen 
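/// Create a collapsed version of `op`: a `GenericOp` needs its indexing maps
/// and iterator types recomputed, while any other `LinalgOp` is cloned as-is
/// onto the collapsed operands.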
1665b6f4dd9eSsrcarroll LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
16665c3ed392SAviad Cohen                            RewriterBase &rewriter) {
1667b6f4dd9eSsrcarroll   if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1668b6f4dd9eSsrcarroll     return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1669b6f4dd9eSsrcarroll   } else {
1670b6f4dd9eSsrcarroll     return cloneToCollapsedOp(rewriter, op, collapsingInfo);
1671b6f4dd9eSsrcarroll   }
1672b6f4dd9eSsrcarroll }
16735c3ed392SAviad Cohen 
1674b6f4dd9eSsrcarroll /// Implementation of fusion with a reshape operation by collapsing dimensions.
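///
/// As a rough sketch (illustrative IR; shapes are hypothetical): collapsing
/// iteration dims [0, 1] of a generic over `tensor<4x8x16xf32>` operands
/// produces
///
///   %in = tensor.collapse_shape %arg0 [[0, 1], [2]]
///       : tensor<4x8x16xf32> into tensor<32x16xf32>
///   %out = linalg.generic ... ins(%in : tensor<32x16xf32>) ...
///   %res = tensor.expand_shape %out [[0, 1], [2]] output_shape [4, 8, 16]
///       : tensor<32x16xf32> into tensor<4x8x16xf32>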
1675b6f4dd9eSsrcarroll FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
1676b6f4dd9eSsrcarroll     LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
1677b6f4dd9eSsrcarroll     RewriterBase &rewriter) {
1678b40e9013SMahesh Ravishankar   // Bail on trivial no-op cases.
16795c3ed392SAviad Cohen   if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
1680b40e9013SMahesh Ravishankar       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
1681b40e9013SMahesh Ravishankar         return foldedDims.size() <= 1;
1682b40e9013SMahesh Ravishankar       }))
1683b40e9013SMahesh Ravishankar     return failure();
16842c58cde0SMahesh Ravishankar 
16850a8e3dd4SMatthias Springer   bool hasPureBufferSemantics = op.hasPureBufferSemantics();
16860a8e3dd4SMatthias Springer   if (hasPureBufferSemantics &&
16875c3ed392SAviad Cohen       !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1688d4ae7ee6SAviad Cohen         MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1689d4ae7ee6SAviad Cohen         if (!memRefToCollapse)
1690d4ae7ee6SAviad Cohen           return true;
1691d4ae7ee6SAviad Cohen 
1692d4ae7ee6SAviad Cohen         return memref::CollapseShapeOp::isGuaranteedCollapsible(
1693d4ae7ee6SAviad Cohen             memRefToCollapse, foldedIterationDims);
1694d4ae7ee6SAviad Cohen       }))
16955c3ed392SAviad Cohen     return rewriter.notifyMatchFailure(op,
1696d4ae7ee6SAviad Cohen                                        "memref is not guaranteed collapsible");
1697d4ae7ee6SAviad Cohen 
1698b40e9013SMahesh Ravishankar   CollapsingInfo collapsingInfo;
16995c3ed392SAviad Cohen   if (failed(
17005c3ed392SAviad Cohen           collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1701b40e9013SMahesh Ravishankar     return rewriter.notifyMatchFailure(
17025c3ed392SAviad Cohen         op, "illegal to collapse specified dimensions");
1703b40e9013SMahesh Ravishankar   }
17042c58cde0SMahesh Ravishankar 
170570e99f38SAlex Zinenko   // Bail on non-canonical ranges.
1706b6f4dd9eSsrcarroll   SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
170770e99f38SAlex Zinenko   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
170868f58812STres Popp     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
17095550c821STres Popp       return cast<IntegerAttr>(attr).getInt() == value;
171070e99f38SAlex Zinenko     llvm::APInt actual;
17114f279a57SKazu Hirata     return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) &&
171270e99f38SAlex Zinenko            actual.getSExtValue() == value;
171370e99f38SAlex Zinenko   };
171470e99f38SAlex Zinenko   if (!llvm::all_of(loopRanges, [&](Range range) {
171570e99f38SAlex Zinenko         return opFoldIsConstantValue(range.offset, 0) &&
171670e99f38SAlex Zinenko                opFoldIsConstantValue(range.stride, 1);
171770e99f38SAlex Zinenko       })) {
171870e99f38SAlex Zinenko     return rewriter.notifyMatchFailure(
17195c3ed392SAviad Cohen         op, "expected all loop ranges to have zero start and unit stride");
172070e99f38SAlex Zinenko   }
172170e99f38SAlex Zinenko 
1722b6f4dd9eSsrcarroll   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
17232c58cde0SMahesh Ravishankar 
17245c3ed392SAviad Cohen   Location loc = op->getLoc();
17255c3ed392SAviad Cohen   if (collapsedOp.hasIndexSemantics()) {
17262c58cde0SMahesh Ravishankar     // Collect the loop bounds of the original op to rewrite the index ops.
17272c58cde0SMahesh Ravishankar     OpBuilder::InsertionGuard g(rewriter);
17285c3ed392SAviad Cohen     rewriter.setInsertionPoint(collapsedOp);
172970e99f38SAlex Zinenko     SmallVector<Value> loopBound =
17305c3ed392SAviad Cohen         llvm::map_to_vector(loopRanges, [&](Range range) {
17314bf84e43SAlexander Belyaev           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
17325c3ed392SAviad Cohen         });
17335c3ed392SAviad Cohen     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
17342c58cde0SMahesh Ravishankar                                     collapsingInfo, loopBound, rewriter);
17352c58cde0SMahesh Ravishankar   }
17362c58cde0SMahesh Ravishankar 
17372c58cde0SMahesh Ravishankar   // Insert an expanding reshape for each result to recover the original
17382c58cde0SMahesh Ravishankar   // result type.
17392c58cde0SMahesh Ravishankar   SmallVector<Value> results;
17405c3ed392SAviad Cohen   for (const auto &originalResult : llvm::enumerate(op->getResults())) {
17415c3ed392SAviad Cohen     Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
17422c58cde0SMahesh Ravishankar     auto originalResultType =
17435550c821STres Popp         cast<ShapedType>(originalResult.value().getType());
17445550c821STres Popp     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
17452c58cde0SMahesh Ravishankar     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
17462c58cde0SMahesh Ravishankar       AffineMap indexingMap =
17475c3ed392SAviad Cohen           op.getIndexingMapMatchingResult(originalResult.value());
17482c58cde0SMahesh Ravishankar       SmallVector<ReassociationIndices> reassociation =
17492c58cde0SMahesh Ravishankar           getOperandReassociation(indexingMap, collapsingInfo);
175097069a86SGaurav Shukla       Value result;
1751d4ae7ee6SAviad Cohen       if (isa<MemRefType>(collapsedOpResult.getType())) {
175297069a86SGaurav Shukla         MemRefType expandShapeResultType = MemRefType::get(
175397069a86SGaurav Shukla             originalResultType.getShape(), originalResultType.getElementType());
175497069a86SGaurav Shukla         result = rewriter.create<memref::ExpandShapeOp>(
175597069a86SGaurav Shukla             loc, expandShapeResultType, collapsedOpResult, reassociation);
17562c58cde0SMahesh Ravishankar       } else {
175797069a86SGaurav Shukla         result = rewriter.create<tensor::ExpandShapeOp>(
17588c0341dfSMehdi Amini             loc, originalResultType, collapsedOpResult, reassociation);
17598c0341dfSMehdi Amini       }
176097069a86SGaurav Shukla       results.push_back(result);
17618c0341dfSMehdi Amini     } else {
17622c58cde0SMahesh Ravishankar       results.push_back(collapsedOpResult);
17632c58cde0SMahesh Ravishankar     }
17642c58cde0SMahesh Ravishankar   }
1765b6f4dd9eSsrcarroll   return CollapseResult{results, collapsedOp};
17662c58cde0SMahesh Ravishankar }
17672c58cde0SMahesh Ravishankar 
17682c58cde0SMahesh Ravishankar namespace {
17692c58cde0SMahesh Ravishankar 
17702c58cde0SMahesh Ravishankar /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
17712c58cde0SMahesh Ravishankar /// collapsing dimensions of the iteration space.
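///
/// A sketch of the intended rewrite (illustrative IR; names and shapes are
/// hypothetical):
///
///   %e = tensor.expand_shape %a [[0, 1], [2]]
///       : tensor<32x16xf32> into tensor<4x8x16xf32>
///   %r = linalg.generic ... ins(%e : tensor<4x8x16xf32>) ...
///
/// becomes a `linalg.generic` operating directly on the unexpanded
/// `tensor<32x16xf32>`, with a `tensor.expand_shape` re-inserted on the
/// result when its rank changes.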
17722c58cde0SMahesh Ravishankar class FoldWithProducerReshapeOpByCollapsing
17732c58cde0SMahesh Ravishankar     : public OpRewritePattern<GenericOp> {
17742c58cde0SMahesh Ravishankar public:
17752291705dSMahesh Ravishankar   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
17762291705dSMahesh Ravishankar                                         ControlFusionFn foldReshapes,
17772c58cde0SMahesh Ravishankar                                         PatternBenefit benefit = 1)
17782c58cde0SMahesh Ravishankar       : OpRewritePattern<GenericOp>(context, benefit),
17792c58cde0SMahesh Ravishankar         controlFoldingReshapes(std::move(foldReshapes)) {}
17802c58cde0SMahesh Ravishankar 
17812c58cde0SMahesh Ravishankar   LogicalResult matchAndRewrite(GenericOp genericOp,
17822c58cde0SMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
1783a7cccb9cSAlexander Belyaev     for (OpOperand &opOperand : genericOp->getOpOperands()) {
17842c58cde0SMahesh Ravishankar       tensor::ExpandShapeOp reshapeOp =
1785a7cccb9cSAlexander Belyaev           opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
17862c58cde0SMahesh Ravishankar       if (!reshapeOp)
17872c58cde0SMahesh Ravishankar         continue;
17882c58cde0SMahesh Ravishankar 
1789b40e9013SMahesh Ravishankar       SmallVector<ReassociationIndices> collapsableIterationDims =
1790a7cccb9cSAlexander Belyaev           getCollapsableIterationSpaceDims(genericOp, &opOperand,
1791b40e9013SMahesh Ravishankar                                            reshapeOp.getReassociationIndices());
1792b40e9013SMahesh Ravishankar       if (collapsableIterationDims.empty() ||
1793a7cccb9cSAlexander Belyaev           !controlFoldingReshapes(&opOperand)) {
17942c58cde0SMahesh Ravishankar         continue;
17952c58cde0SMahesh Ravishankar       }
17962c58cde0SMahesh Ravishankar 
1797b6f4dd9eSsrcarroll       std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
17985c3ed392SAviad Cohen           genericOp, collapsableIterationDims, rewriter);
1799b6f4dd9eSsrcarroll       if (!collapseResult) {
18002c58cde0SMahesh Ravishankar         return rewriter.notifyMatchFailure(
18012c58cde0SMahesh Ravishankar             genericOp, "failed to do the fusion by collapsing transformation");
18022c58cde0SMahesh Ravishankar       }
18032c58cde0SMahesh Ravishankar 
1804b6f4dd9eSsrcarroll       rewriter.replaceOp(genericOp, collapseResult->results);
18052c58cde0SMahesh Ravishankar       return success();
18062c58cde0SMahesh Ravishankar     }
18072c58cde0SMahesh Ravishankar     return failure();
18082c58cde0SMahesh Ravishankar   }
18092c58cde0SMahesh Ravishankar 
18102c58cde0SMahesh Ravishankar private:
18112291705dSMahesh Ravishankar   ControlFusionFn controlFoldingReshapes;
18122c58cde0SMahesh Ravishankar };
181383c65fbcSThomas Raoux 
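/// Pattern to fold a tensor.pad op with its producer tensor.expand_shape op
/// by collapsing: the pad is rewritten on the unexpanded source and the
/// expand_shape is re-created on the padded result. This only applies when
/// no dimension that the reassociation splits is padded. A sketch
/// (illustrative IR; shapes are hypothetical):
///
///   %e = tensor.expand_shape %a [[0, 1], [2]]
///       : tensor<32x16xf32> into tensor<4x8x16xf32>
///   %p = tensor.pad %e low[0, 0, 2] high[0, 0, 2] { ... }
///
/// becomes, roughly,
///
///   %p = tensor.pad %a low[0, 2] high[0, 2] { ... }
///   %e = tensor.expand_shape %p [[0, 1], [2]]
///       : tensor<32x20xf32> into tensor<4x8x20xf32>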
1814c886d66dSMax191 class FoldPadWithProducerReshapeOpByCollapsing
1815c886d66dSMax191     : public OpRewritePattern<tensor::PadOp> {
1816c886d66dSMax191 public:
1817c886d66dSMax191   FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
1818c886d66dSMax191                                            ControlFusionFn foldReshapes,
1819c886d66dSMax191                                            PatternBenefit benefit = 1)
1820c886d66dSMax191       : OpRewritePattern<tensor::PadOp>(context, benefit),
1821c886d66dSMax191         controlFoldingReshapes(std::move(foldReshapes)) {}
1822c886d66dSMax191 
1823c886d66dSMax191   LogicalResult matchAndRewrite(tensor::PadOp padOp,
1824c886d66dSMax191                                 PatternRewriter &rewriter) const override {
1825c886d66dSMax191     tensor::ExpandShapeOp reshapeOp =
1826c886d66dSMax191         padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
1827c886d66dSMax191     if (!reshapeOp)
1828c886d66dSMax191       return failure();
1829c886d66dSMax191     if (!reshapeOp->hasOneUse())
1830c886d66dSMax191       return failure();
1831c886d66dSMax191 
1832c886d66dSMax191     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1833c886d66dSMax191       return rewriter.notifyMatchFailure(padOp,
1834c886d66dSMax191                                          "fusion blocked by control function");
1835c886d66dSMax191     }
1836c886d66dSMax191 
1837c886d66dSMax191     ArrayRef<int64_t> low = padOp.getStaticLow();
1838c886d66dSMax191     ArrayRef<int64_t> high = padOp.getStaticHigh();
1839c886d66dSMax191     SmallVector<ReassociationIndices> reassociations =
1840c886d66dSMax191         reshapeOp.getReassociationIndices();
1841c886d66dSMax191 
1842c886d66dSMax191     for (auto reInd : reassociations) {
1843c886d66dSMax191       if (reInd.size() == 1)
1844c886d66dSMax191         continue;
1845c886d66dSMax191       if (llvm::any_of(reInd, [&](int64_t ind) {
1846c886d66dSMax191             return low[ind] != 0 || high[ind] != 0;
1847c886d66dSMax191           })) {
1848c886d66dSMax191         return failure();
1849c886d66dSMax191       }
1850c886d66dSMax191     }
1851c886d66dSMax191 
1852c886d66dSMax191     SmallVector<OpFoldResult> newLow, newHigh;
1853c886d66dSMax191     RankedTensorType collapsedType = reshapeOp.getSrcType();
1854c886d66dSMax191     RankedTensorType paddedType = padOp.getResultType();
1855c886d66dSMax191     SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
1856c886d66dSMax191     SmallVector<OpFoldResult> expandedPaddedSizes(
1857c886d66dSMax191         getMixedValues(reshapeOp.getStaticOutputShape(),
1858c886d66dSMax191                        reshapeOp.getOutputShape(), rewriter));
1859c886d66dSMax191     AffineExpr d0, d1, d2;
1860c886d66dSMax191     bindDims(rewriter.getContext(), d0, d1, d2);
1861c886d66dSMax191     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
1862c886d66dSMax191     Location loc = reshapeOp->getLoc();
1863c886d66dSMax191     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
1864c886d66dSMax191       OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
1865c886d66dSMax191       OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
1866c886d66dSMax191       if (reInd.size() == 1) {
1867c886d66dSMax191         collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
1868c886d66dSMax191         OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
1869c886d66dSMax191             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
1870c886d66dSMax191         expandedPaddedSizes[reInd[0]] = paddedSize;
1871c886d66dSMax191       }
1872c886d66dSMax191       newLow.push_back(l);
1873c886d66dSMax191       newHigh.push_back(h);
1874c886d66dSMax191     }
1875c886d66dSMax191 
1876c886d66dSMax191     RankedTensorType collapsedPaddedType =
1877c886d66dSMax191         paddedType.clone(collapsedPaddedShape);
1878c886d66dSMax191     auto newPadOp = rewriter.create<tensor::PadOp>(
1879c886d66dSMax191         loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
1880c886d66dSMax191         padOp.getConstantPaddingValue(), padOp.getNofold());
1881c886d66dSMax191 
1882c886d66dSMax191     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1883c886d66dSMax191         padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
1884c886d66dSMax191         expandedPaddedSizes);
1885c886d66dSMax191 
1886c886d66dSMax191     return success();
1887c886d66dSMax191   }
1888c886d66dSMax191 
1889c886d66dSMax191 private:
1890c886d66dSMax191   ControlFusionFn controlFoldingReshapes;
1891c886d66dSMax191 };
1892c886d66dSMax191 
189383c65fbcSThomas Raoux /// Pattern to collapse `LinalgOp` dimensions selected by a control function.
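///
/// A minimal control-function sketch (hypothetical client code; real clients
/// would inspect iterator types and indexing maps before choosing dims):
///
///   GetCollapsableDimensionsFn fn = [](linalg::LinalgOp op) {
///     SmallVector<ReassociationIndices> dims;
///     if (op.getNumLoops() == 3)
///       dims.push_back(ReassociationIndices{0, 1});
///     return dims;
///   };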
18945c3ed392SAviad Cohen template <typename LinalgType>
18955c3ed392SAviad Cohen class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
189683c65fbcSThomas Raoux public:
189783c65fbcSThomas Raoux   CollapseLinalgDimensions(MLIRContext *context,
189883c65fbcSThomas Raoux                            GetCollapsableDimensionsFn collapseDimensions,
189983c65fbcSThomas Raoux                            PatternBenefit benefit = 1)
19005c3ed392SAviad Cohen       : OpRewritePattern<LinalgType>(context, benefit),
190183c65fbcSThomas Raoux         controlCollapseDimension(std::move(collapseDimensions)) {}
190283c65fbcSThomas Raoux 
19035c3ed392SAviad Cohen   LogicalResult matchAndRewrite(LinalgType op,
190483c65fbcSThomas Raoux                                 PatternRewriter &rewriter) const override {
190583c65fbcSThomas Raoux     SmallVector<ReassociationIndices> collapsableIterationDims =
19065c3ed392SAviad Cohen         controlCollapseDimension(op);
190783c65fbcSThomas Raoux     if (collapsableIterationDims.empty())
190883c65fbcSThomas Raoux       return failure();
190983c65fbcSThomas Raoux 
1910f12639d0SMahesh Ravishankar     // Check if the specified list of dimensions to collapse is a valid list.
19115c3ed392SAviad Cohen     if (!areDimSequencesPreserved(op.getIndexingMapsArray(),
1912f12639d0SMahesh Ravishankar                                   collapsableIterationDims)) {
1913f12639d0SMahesh Ravishankar       return rewriter.notifyMatchFailure(
19145c3ed392SAviad Cohen           op, "specified dimensions cannot be collapsed");
1915f12639d0SMahesh Ravishankar     }
1916f12639d0SMahesh Ravishankar 
1917b6f4dd9eSsrcarroll     std::optional<CollapseResult> collapseResult =
1918b6f4dd9eSsrcarroll         collapseOpIterationDims(op, collapsableIterationDims, rewriter);
1919b6f4dd9eSsrcarroll     if (!collapseResult) {
19205c3ed392SAviad Cohen       return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
192183c65fbcSThomas Raoux     }
1922b6f4dd9eSsrcarroll     rewriter.replaceOp(op, collapseResult->results);
192383c65fbcSThomas Raoux     return success();
192483c65fbcSThomas Raoux   }
192583c65fbcSThomas Raoux 
192683c65fbcSThomas Raoux private:
192783c65fbcSThomas Raoux   GetCollapsableDimensionsFn controlCollapseDimension;
192883c65fbcSThomas Raoux };
192983c65fbcSThomas Raoux 
19302c58cde0SMahesh Ravishankar } // namespace
19312c58cde0SMahesh Ravishankar 
19322c58cde0SMahesh Ravishankar //===---------------------------------------------------------------------===//
193332288d37SMahesh Ravishankar // Methods and patterns that fuse constants with linalg.generic operations.
193432288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
19355994201cSMaheshRavishankar 
193632288d37SMahesh Ravishankar namespace {
1937a40a08edSMaheshRavishankar /// Pattern to fold a splat or scalar constant into a generic op. Does not
1938a40a08edSMaheshRavishankar /// handle cases where the constant is not single-valued.
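///
/// A sketch (illustrative IR): an input such as
///
///   %cst = arith.constant dense<4.2> : tensor<4x8xf32>
///   %0 = linalg.generic ... ins(%arg0, %cst : ...) ...
///
/// is dropped from the operand list, and the corresponding payload block
/// argument is replaced by a scalar `arith.constant 4.200000e+00 : f32`.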
19394cd7ff67SLei Zhang class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
19405994201cSMaheshRavishankar public:
19412291705dSMahesh Ravishankar   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
19422291705dSMahesh Ravishankar       : OpRewritePattern<GenericOp>(context, benefit) {}
19435994201cSMaheshRavishankar 
19445994201cSMaheshRavishankar   LogicalResult matchAndRewrite(GenericOp genericOp,
19455994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
19460a8e3dd4SMatthias Springer     if (!genericOp.hasPureTensorSemantics())
19475994201cSMaheshRavishankar       return failure();
1948b4db15a9SAlexander Belyaev     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
19495994201cSMaheshRavishankar       Operation *def = opOperand->get().getDefiningOp();
1950e1795322SJeff Niu       TypedAttr constantAttr;
1951a40a08edSMaheshRavishankar       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1952a40a08edSMaheshRavishankar         {
1953a40a08edSMaheshRavishankar           DenseElementsAttr splatAttr;
1954a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1955a40a08edSMaheshRavishankar               splatAttr.isSplat() &&
1956a40a08edSMaheshRavishankar               splatAttr.getType().getElementType().isIntOrFloat()) {
1957e1795322SJeff Niu             constantAttr = splatAttr.getSplatValue<TypedAttr>();
1958a40a08edSMaheshRavishankar             return true;
1959a40a08edSMaheshRavishankar           }
1960a40a08edSMaheshRavishankar         }
1961a40a08edSMaheshRavishankar         {
1962a40a08edSMaheshRavishankar           IntegerAttr intAttr;
1963a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1964a40a08edSMaheshRavishankar             constantAttr = intAttr;
1965a40a08edSMaheshRavishankar             return true;
1966a40a08edSMaheshRavishankar           }
1967a40a08edSMaheshRavishankar         }
1968a40a08edSMaheshRavishankar         {
1969a40a08edSMaheshRavishankar           FloatAttr floatAttr;
1970a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1971a40a08edSMaheshRavishankar             constantAttr = floatAttr;
1972a40a08edSMaheshRavishankar             return true;
1973a40a08edSMaheshRavishankar           }
1974a40a08edSMaheshRavishankar         }
1975a40a08edSMaheshRavishankar         return false;
1976a40a08edSMaheshRavishankar       };
1977a40a08edSMaheshRavishankar 
19785550c821STres Popp       auto resultValue = dyn_cast<OpResult>(opOperand->get());
19792291705dSMahesh Ravishankar       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
19805994201cSMaheshRavishankar         continue;
19815994201cSMaheshRavishankar 
19825994201cSMaheshRavishankar       // The operands and the indexing_maps of the fused operation are the
19835994201cSMaheshRavishankar       // same as the operands and indexing_maps of the generic operation,
19845994201cSMaheshRavishankar       // with the entry for the constant operand dropped.
19855994201cSMaheshRavishankar       SmallVector<AffineMap> fusedIndexMaps;
19865994201cSMaheshRavishankar       SmallVector<Value> fusedOperands;
1987b983783dSGeoffrey Martin-Noble       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1988a7cccb9cSAlexander Belyaev       fusedIndexMaps.reserve(genericOp->getNumOperands());
1989b4db15a9SAlexander Belyaev       fusedOperands.reserve(genericOp.getNumDpsInputs());
1990b4db15a9SAlexander Belyaev       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1991b4db15a9SAlexander Belyaev       for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
19925994201cSMaheshRavishankar         if (inputOperand == opOperand)
19935994201cSMaheshRavishankar           continue;
1994b983783dSGeoffrey Martin-Noble         Value inputValue = inputOperand->get();
19951227b8abSOleg Shyshkov         fusedIndexMaps.push_back(
19961227b8abSOleg Shyshkov             genericOp.getMatchingIndexingMap(inputOperand));
1997b983783dSGeoffrey Martin-Noble         fusedOperands.push_back(inputValue);
1998b983783dSGeoffrey Martin-Noble         fusedLocs.push_back(inputValue.getLoc());
19995994201cSMaheshRavishankar       }
20000b2197b0SMatthias Springer       for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
20011227b8abSOleg Shyshkov         fusedIndexMaps.push_back(
20020b2197b0SMatthias Springer             genericOp.getMatchingIndexingMap(&outputOperand));
20035994201cSMaheshRavishankar 
20045994201cSMaheshRavishankar       // Check if the shapes-to-loops map of the operation is computable.
200506514c55SIan Wood       if (!inversePermutation(
200606514c55SIan Wood               concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
20075994201cSMaheshRavishankar         return rewriter.notifyMatchFailure(
20085994201cSMaheshRavishankar             genericOp, "fused op loop bound computation failed");
20095994201cSMaheshRavishankar       }
20105994201cSMaheshRavishankar 
20115994201cSMaheshRavishankar       // Create a constant scalar value from the splat constant.
201200e3566dSRahul Kayaith       Value scalarConstant =
201300e3566dSRahul Kayaith           rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
20145994201cSMaheshRavishankar 
2015a7cccb9cSAlexander Belyaev       SmallVector<Value> outputOperands = genericOp.getOutputs();
20165994201cSMaheshRavishankar       auto fusedOp = rewriter.create<GenericOp>(
2017b983783dSGeoffrey Martin-Noble           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
20185994201cSMaheshRavishankar           /*inputs=*/fusedOperands,
20195994201cSMaheshRavishankar           /*outputs=*/outputOperands,
20205994201cSMaheshRavishankar           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2021d3b3f765SJacques Pienaar           genericOp.getIteratorTypes(),
20225994201cSMaheshRavishankar           /*doc=*/nullptr,
20235994201cSMaheshRavishankar           /*library_call=*/nullptr);
20245994201cSMaheshRavishankar 
20255994201cSMaheshRavishankar       // Map the block argument corresponding to the replaced argument with the
20265994201cSMaheshRavishankar       // scalar constant.
20275994201cSMaheshRavishankar       Region &region = genericOp->getRegion(0);
20285994201cSMaheshRavishankar       Block &entryBlock = *region.begin();
20294d67b278SJeff Niu       IRMapping mapping;
20305994201cSMaheshRavishankar       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
20315994201cSMaheshRavishankar                   scalarConstant);
20325994201cSMaheshRavishankar       Region &fusedRegion = fusedOp->getRegion(0);
20335994201cSMaheshRavishankar       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
20345994201cSMaheshRavishankar                                  mapping);
20355994201cSMaheshRavishankar       rewriter.replaceOp(genericOp, fusedOp->getResults());
20365994201cSMaheshRavishankar       return success();
20375994201cSMaheshRavishankar     }
20385994201cSMaheshRavishankar     return failure();
20395994201cSMaheshRavishankar   }
20404cd7ff67SLei Zhang };
20414cd7ff67SLei Zhang 
20425994201cSMaheshRavishankar } // namespace
20435994201cSMaheshRavishankar 
204432288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
204532288d37SMahesh Ravishankar // Miscellaneous patterns that help fusion.
204632288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
20475994201cSMaheshRavishankar 
20485994201cSMaheshRavishankar namespace {
204981ca5aa4SMatthias Springer /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
205081ca5aa4SMatthias Springer /// value of the `outs` operand is not used within the op. This is only
20515994201cSMaheshRavishankar /// implemented for `linalg.generic` operations for now, but should hold for
20525994201cSMaheshRavishankar /// all linalg structured ops.
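///
/// A sketch (illustrative IR): when the payload never reads the `outs` value,
///
///   %0 = linalg.generic ... outs(%computed : tensor<?x?xf32>) ...
///
/// is rewritten to use a fresh `tensor.empty` of the same shape as the
/// `outs` operand, breaking the false dependency on `%computed`.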
20535994201cSMaheshRavishankar struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
20545994201cSMaheshRavishankar   using OpRewritePattern<GenericOp>::OpRewritePattern;
20555994201cSMaheshRavishankar 
20565994201cSMaheshRavishankar   LogicalResult matchAndRewrite(GenericOp op,
20575994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
20585fcf907bSMatthias Springer     rewriter.startOpModification(op);
20595994201cSMaheshRavishankar     bool modifiedOutput = false;
20605994201cSMaheshRavishankar     Location loc = op.getLoc();
20610b2197b0SMatthias Springer     for (OpOperand &opOperand : op.getDpsInitsMutable()) {
20620b2197b0SMatthias Springer       if (!op.payloadUsesValueFromOperand(&opOperand)) {
20630b2197b0SMatthias Springer         Value operandVal = opOperand.get();
20645550c821STres Popp         auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
20655994201cSMaheshRavishankar         if (!operandType)
20665994201cSMaheshRavishankar           continue;
20675994201cSMaheshRavishankar 
2068c43e6274STim Harvey         // If outs is sparse, leave it to the sparsifier.
2069515c6170SAart Bik         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2070515c6170SAart Bik           continue;
2071515c6170SAart Bik 
207281ca5aa4SMatthias Springer         // If outs is already an `empty` operation, nothing to do.
207381ca5aa4SMatthias Springer         auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
20745994201cSMaheshRavishankar         if (definingOp)
20755994201cSMaheshRavishankar           continue;
20765994201cSMaheshRavishankar         modifiedOutput = true;
20776596b0ddSMatthias Springer         SmallVector<OpFoldResult> mixedSizes =
20786596b0ddSMatthias Springer             tensor::getMixedSizes(rewriter, loc, operandVal);
207981ca5aa4SMatthias Springer         Value emptyTensor = rewriter.create<tensor::EmptyOp>(
20806596b0ddSMatthias Springer             loc, mixedSizes, operandType.getElementType());
20810b2197b0SMatthias Springer         op->setOperand(opOperand.getOperandNumber(), emptyTensor);
20825994201cSMaheshRavishankar       }
20835994201cSMaheshRavishankar     }
20845994201cSMaheshRavishankar     if (!modifiedOutput) {
20855fcf907bSMatthias Springer       rewriter.cancelOpModification(op);
20865994201cSMaheshRavishankar       return failure();
20875994201cSMaheshRavishankar     }
20885fcf907bSMatthias Springer     rewriter.finalizeOpModification(op);
20895994201cSMaheshRavishankar     return success();
20905994201cSMaheshRavishankar   }
20915994201cSMaheshRavishankar };
20925994201cSMaheshRavishankar 
209301055ed1SNirvedh /// Fold a `linalg.fill` producer into the payload of a `linalg.generic`.
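///
/// A sketch (illustrative IR; names are hypothetical):
///
///   %fill = linalg.fill ins(%cst : f32)
///               outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %0 = linalg.generic ... ins(%fill : tensor<?x?xf32>) ...
///
/// The payload block argument tied to `%fill` is replaced by `%cst`
/// (converted to the element type if needed); the dead-operand cleanup
/// patterns can then remove the fill input entirely.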
209401055ed1SNirvedh struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
209501055ed1SNirvedh   using OpRewritePattern<GenericOp>::OpRewritePattern;
209601055ed1SNirvedh 
209701055ed1SNirvedh   LogicalResult matchAndRewrite(GenericOp genericOp,
209801055ed1SNirvedh                                 PatternRewriter &rewriter) const override {
20990a8e3dd4SMatthias Springer     if (!genericOp.hasPureTensorSemantics())
210001055ed1SNirvedh       return failure();
210101055ed1SNirvedh     bool fillFound = false;
2102d3b3f765SJacques Pienaar     Block &payload = genericOp.getRegion().front();
2103b4db15a9SAlexander Belyaev     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
210401055ed1SNirvedh       if (!genericOp.payloadUsesValueFromOperand(opOperand))
210501055ed1SNirvedh         continue;
210601055ed1SNirvedh       FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
210701055ed1SNirvedh       if (!fillOp)
210801055ed1SNirvedh         continue;
210901055ed1SNirvedh       fillFound = true;
211004b449e1SPrashant Kumar       Value fillVal = fillOp.value();
211104b449e1SPrashant Kumar       auto resultType =
21125550c821STres Popp           cast<RankedTensorType>(fillOp.result().getType()).getElementType();
211304b449e1SPrashant Kumar       Value convertedVal =
211404b449e1SPrashant Kumar           convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
211504b449e1SPrashant Kumar                                /*isUnsignedCast=*/false);
2116b13a197eSMatthias Springer       rewriter.replaceAllUsesWith(
2117b13a197eSMatthias Springer           payload.getArgument(opOperand->getOperandNumber()), convertedVal);
211801055ed1SNirvedh     }
211901055ed1SNirvedh     return success(fillFound);
212001055ed1SNirvedh   }
212101055ed1SNirvedh };
212201055ed1SNirvedh } // namespace
21235994201cSMaheshRavishankar 
21245994201cSMaheshRavishankar void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
21255994201cSMaheshRavishankar     RewritePatternSet &patterns,
21262291705dSMahesh Ravishankar     const ControlFusionFn &controlFoldingReshapes) {
2127b546f434SMaheshRavishankar   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
2128b546f434SMaheshRavishankar                                                     controlFoldingReshapes);
2129c886d66dSMax191   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
2130c886d66dSMax191                                                         controlFoldingReshapes);
21315994201cSMaheshRavishankar   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
21325994201cSMaheshRavishankar                                                      controlFoldingReshapes);
21335994201cSMaheshRavishankar }
21345994201cSMaheshRavishankar 
21352c58cde0SMahesh Ravishankar void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
21362c58cde0SMahesh Ravishankar     RewritePatternSet &patterns,
21372291705dSMahesh Ravishankar     const ControlFusionFn &controlFoldingReshapes) {
21382c58cde0SMahesh Ravishankar   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
21392c58cde0SMahesh Ravishankar                                                       controlFoldingReshapes);
2140c886d66dSMax191   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
2141c886d66dSMax191       patterns.getContext(), controlFoldingReshapes);
21422c58cde0SMahesh Ravishankar }
21432c58cde0SMahesh Ravishankar 
21445994201cSMaheshRavishankar void mlir::linalg::populateElementwiseOpsFusionPatterns(
21452291705dSMahesh Ravishankar     RewritePatternSet &patterns,
21462291705dSMahesh Ravishankar     const ControlFusionFn &controlElementwiseOpsFusion) {
21475994201cSMaheshRavishankar   auto *context = patterns.getContext();
21482291705dSMahesh Ravishankar   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
21492291705dSMahesh Ravishankar   patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
21502291705dSMahesh Ravishankar                RemoveOutsDependency>(context);
2151da8a8e92SMahesh Ravishankar   // Add the patterns that clean up dead operands and results.
2152da8a8e92SMahesh Ravishankar   populateEraseUnusedOperandsAndResultsPatterns(patterns);
21535994201cSMaheshRavishankar }
21545994201cSMaheshRavishankar 
215583c65fbcSThomas Raoux void mlir::linalg::populateCollapseDimensions(
215683c65fbcSThomas Raoux     RewritePatternSet &patterns,
215783c65fbcSThomas Raoux     const GetCollapsableDimensionsFn &controlCollapseDimensions) {
21585c3ed392SAviad Cohen   patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
21595c3ed392SAviad Cohen                CollapseLinalgDimensions<linalg::CopyOp>>(
21605c3ed392SAviad Cohen       patterns.getContext(), controlCollapseDimensions);
216183c65fbcSThomas Raoux }
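
// A minimal usage sketch (hypothetical client code) for wiring these
// populate entry points into a custom pass:
//
//   RewritePatternSet patterns(ctx);
//   ControlFusionFn control = [](OpOperand *fusedOperand) {
//     Operation *producer = fusedOperand->get().getDefiningOp();
//     return producer && producer->hasOneUse();
//   };
//   linalg::populateElementwiseOpsFusionPatterns(patterns, control);
//   linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, control);
//   (void)applyPatternsGreedily(op, std::move(patterns));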
216283c65fbcSThomas Raoux 
216332288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
216432288d37SMahesh Ravishankar // Passes
216532288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
216632288d37SMahesh Ravishankar 
216732288d37SMahesh Ravishankar namespace {
216832288d37SMahesh Ravishankar 
216932288d37SMahesh Ravishankar /// Pass that fuses generic ops on tensors. Used only for testing.
21702291705dSMahesh Ravishankar // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the
21712291705dSMahesh Ravishankar // patterns added here heavily depends on the cost function used. Having an
21722291705dSMahesh Ravishankar // opinionated pass of this form is not recommended. Deprecate this pass in
21732291705dSMahesh Ravishankar // favor of test passes that check the functionality of each of the patterns
21742291705dSMahesh Ravishankar // added here individually.
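// The pass is typically exercised via `mlir-opt -linalg-fuse-elementwise-ops`
// (assuming the usual tablegen-derived flag name).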
217532288d37SMahesh Ravishankar struct LinalgElementwiseOpFusionPass
21761e98d488SQuinn Dawkins     : public impl::LinalgElementwiseOpFusionPassBase<
217767d0d7acSMichele Scuttari           LinalgElementwiseOpFusionPass> {
21781e98d488SQuinn Dawkins   using impl::LinalgElementwiseOpFusionPassBase<
21791e98d488SQuinn Dawkins       LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
218032288d37SMahesh Ravishankar   void runOnOperation() override {
218132288d37SMahesh Ravishankar     Operation *op = getOperation();
21822291705dSMahesh Ravishankar     MLIRContext *context = op->getContext();
21832291705dSMahesh Ravishankar     RewritePatternSet patterns(context);
21842291705dSMahesh Ravishankar 
21852291705dSMahesh Ravishankar     // Add folding with reshape by expansion patterns.
2186a7bfdc23SMahesh Ravishankar     ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2187a7bfdc23SMahesh Ravishankar       Operation *producer = fusedOperand->get().getDefiningOp();
2188a7bfdc23SMahesh Ravishankar       return producer && producer->hasOneUse();
218932288d37SMahesh Ravishankar     };
21902291705dSMahesh Ravishankar 
21912291705dSMahesh Ravishankar     // Add elementwise op fusion patterns.
21922291705dSMahesh Ravishankar     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
21930c090dccSMahesh Ravishankar     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2194a95ad2daSIan Wood     tensor::populateBubbleUpExpandShapePatterns(patterns);
21952291705dSMahesh Ravishankar 
21962291705dSMahesh Ravishankar     // General canonicalization patterns.
21974c48f016SMatthias Springer     affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
21982291705dSMahesh Ravishankar     GenericOp::getCanonicalizationPatterns(patterns, context);
21992291705dSMahesh Ravishankar     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
22002291705dSMahesh Ravishankar     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
22012291705dSMahesh Ravishankar     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
22022291705dSMahesh Ravishankar         patterns);
22032291705dSMahesh Ravishankar 
22042291705dSMahesh Ravishankar     // Add constant folding patterns.
22052291705dSMahesh Ravishankar     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
220632288d37SMahesh Ravishankar 
220732288d37SMahesh Ravishankar     // Use TopDownTraversal for compile-time reasons.
220832288d37SMahesh Ravishankar     GreedyRewriteConfig grc;
220932288d37SMahesh Ravishankar     grc.useTopDownTraversal = true;
221009dfc571SJacques Pienaar     (void)applyPatternsGreedily(op, std::move(patterns), grc);
221132288d37SMahesh Ravishankar   }
221232288d37SMahesh Ravishankar };
221332288d37SMahesh Ravishankar 
221432288d37SMahesh Ravishankar } // namespace
2215