//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, and if the remaining operands
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
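  // Illustrative example (hypothetical maps, not from a specific test): if
  // the remaining maps are (d0, d1) -> (d0) and (d0, d1) -> (d1), their
  // concatenation (d0, d1) -> (d0, d1) is an invertible permutation, so the
  // operand can be dropped. If both remaining maps were (d0, d1) -> (d0),
  // the extent of d1 would become unrecoverable and inversePermutation
  // would return an empty map.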
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction
/// based on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that
  // defines it.
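  // Illustrative example (hypothetical maps): if the consumer reduces d1 and
  // the fused operand is the only operand whose indexing map references d1,
  // then after fusion the d1 loop bound must still be derivable from one of
  // the producer inputs that gets pulled in; otherwise the fusion is
  // rejected by the check below.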
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
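    // Illustrative example (hypothetical map): with
    // consumerToProducerLoopsMap = (d0, d1) -> (d1, d0), a `linalg.index 0`
    // in the producer block is remapped below to an `affine.apply` that
    // selects the fused op's index 1, and vice versa.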
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
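  // The fused operand list mirrors the fused block arguments built in
  // `generateFusedElementwiseOpRegion`: consumer inputs before
  // `fusedOperand`, then all producer inputs, then the remaining consumer
  // inputs, followed by the preserved producer inits and the consumer inits.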
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
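  // Illustrative example (hypothetical maps): if the producer result map is
  // (p0, p1) -> (p1, p0), its inverse is (t0, t1) -> (t1, t0); composing it
  // below with a consumer map (c0, c1) -> (c0, c1) yields
  // (c0, c1) -> (c1, c0), i.e. consumer loop 0 drives producer loop 1 and
  // vice versa.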
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Perform the fusion.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Replace uses of the original results with the fused op's results.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
///  The reshape can be folded into the `linalgOp` if its loop dimensionality
///  is increased to match the result (operand) of the tensor.expand_shape.
///  The indexing_map of the fused tensor in the `linalgOp` and the
///  reassociation map helps compute the indexing maps of the modified op.
///  For the above example, based on the reassociation map it
///  can be concluded that
///
///  - The loop used to access the first dimension of the fused tensor is split
///    into two.
///  - The loop used to access the second dimension of the fused tensor is kept
///    as is.
///  - The loop used to access the third dimension of the fused tensor is split
///    into three.
///
///  i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
///  modified op, then
///
///   d0 -> e0, e1
///   d1 -> e2, e3, e4
///   d2 -> e5
///
///  substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
///  Since the iteration space of the modified op is now 6-D, reshapes can be
///  introduced to make the operands consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
///  The added reshapes are again expanding patterns, so they will get fused
///  with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
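  // Illustrative example (hypothetical maps): if the fused operand map is
  // (d0, d1) -> (d0, d1) and the reshape splits its second result into two
  // dims (reassociation [[0], [1, 2]]), then numExpandedDims = [1, 2], d1
  // expands into two loops, and the expanded op has 3 loops in total.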
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute the reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
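/// For example (illustrative): expanding a loop to extents (?, 4) is
/// accepted (only the outermost expanded extent may be dynamic), while
/// (4, ?) is rejected.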
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
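/// For example (illustrative): a tensor<?x4xf32> operand whose second
/// dimension expands into extents (2, 2) gets the type tensor<?x2x2xf32>.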
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
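  // Illustrative example (hypothetical expansion): if loop d0 expands into
  // (e0, e1) with e1 of static extent 4, a `linalg.index 0` is replaced
  // below by `affine.apply (i1 + i0 * 4)` over the values of the new
  // `linalg.index 1` and `linalg.index 0` ops.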
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
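/// For example (illustrative): expanding a dynamic dimension into extents
/// (?, 4, ?) cannot be inferred and is rejected, while (?, 4, 2) is
/// accepted.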
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  // TODO: With the support of multiple dynamic dims expansion in
  // tensor.expand_shape op, this case can be handled.
  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
8805994201cSMaheshRavishankar       SmallVector<ReassociationIndices> reassociation =
8815994201cSMaheshRavishankar           getReassociationForExpansion(indexingMap, expansionInfo);
882ff5de8a9SBenjamin Kramer       if (failed(reshapeLikeShapesAreCompatible(
883ff5de8a9SBenjamin Kramer               [&](const Twine &msg) {
8843f18f6a2SQuinn Dawkins                 return rewriter.notifyMatchFailure(linalgOp, msg);
885ff5de8a9SBenjamin Kramer               },
886ff5de8a9SBenjamin Kramer               opOperandType.getShape(), expandedOperandType.getShape(),
887ff5de8a9SBenjamin Kramer               reassociation,
888ff5de8a9SBenjamin Kramer               /*isExpandingReshape=*/true)))
8891a36588eSKazu Hirata         return std::nullopt;
890b618880eSAlexander Belyaev       expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
89197069a86SGaurav Shukla           loc, expandedOperandType, opOperand->get(), reassociation));
8925994201cSMaheshRavishankar       continue;
8935994201cSMaheshRavishankar     }
8945994201cSMaheshRavishankar   }
8955994201cSMaheshRavishankar   expandedOpOperands.push_back(opOperand->get());
8965994201cSMaheshRavishankar }
8975994201cSMaheshRavishankar
8985994201cSMaheshRavishankar   SmallVector<Value> outputs;
8993f18f6a2SQuinn Dawkins   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
9003f18f6a2SQuinn Dawkins     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
9010b2197b0SMatthias Springer     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
9025994201cSMaheshRavishankar     RankedTensorType expandedOutputType =
903ff5de8a9SBenjamin Kramer         getExpandedType(opOperandType, indexingMap, expansionInfo);
9040b2197b0SMatthias Springer     if (expandedOutputType != opOperand.get().getType()) {
9055994201cSMaheshRavishankar       SmallVector<ReassociationIndices> reassociation =
9065994201cSMaheshRavishankar           getReassociationForExpansion(indexingMap, expansionInfo);
907ff5de8a9SBenjamin Kramer       if (failed(reshapeLikeShapesAreCompatible(
908ff5de8a9SBenjamin Kramer               [&](const Twine &msg) {
9093f18f6a2SQuinn Dawkins                 return rewriter.notifyMatchFailure(linalgOp, msg);
910ff5de8a9SBenjamin Kramer               },
911ff5de8a9SBenjamin Kramer               opOperandType.getShape(), expandedOutputType.getShape(),
912ff5de8a9SBenjamin Kramer               reassociation,
913ff5de8a9SBenjamin Kramer               /*isExpandingReshape=*/true)))
9141a36588eSKazu Hirata         return std::nullopt;
915b618880eSAlexander Belyaev       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
91697069a86SGaurav Shukla           loc, expandedOutputType, opOperand.get(), reassociation));
91771604f4cSMahesh Ravishankar     } else {
9180b2197b0SMatthias Springer       outputs.push_back(opOperand.get());
9195994201cSMaheshRavishankar     }
9205994201cSMaheshRavishankar   }
9215994201cSMaheshRavishankar
9225994201cSMaheshRavishankar   // Expanded dims inherit the iterator type of the original dimension.
923e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType> iteratorTypes(
924e6598b05SOleg Shyshkov       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
9253f18f6a2SQuinn Dawkins   for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
9263f18f6a2SQuinn Dawkins     for (auto j : expansionInfo.getExpandedDims(i))
9273f18f6a2SQuinn Dawkins       iteratorTypes[j] = type;
9285994201cSMaheshRavishankar
9295994201cSMaheshRavishankar   TypeRange resultTypes = ValueRange(outputs).getTypes();
9305994201cSMaheshRavishankar   auto fusedOp =
9313f18f6a2SQuinn Dawkins       rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
9325994201cSMaheshRavishankar                                  /*inputs=*/expandedOpOperands, outputs,
9335994201cSMaheshRavishankar                                  expandedOpIndexingMaps, iteratorTypes);
9345994201cSMaheshRavishankar   Region &fusedRegion = fusedOp->getRegion(0);
9353f18f6a2SQuinn Dawkins   Region &originalRegion = linalgOp->getRegion(0);
9365994201cSMaheshRavishankar   rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
9375994201cSMaheshRavishankar
9385994201cSMaheshRavishankar   // Update the index accesses after the expansion.
9395994201cSMaheshRavishankar   updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
9405994201cSMaheshRavishankar
9415994201cSMaheshRavishankar   // Reshape the result values to their original shape if this is a collapsing
9425994201cSMaheshRavishankar   // reshape folded into its consumer.
9435994201cSMaheshRavishankar   SmallVector<Value> resultVals;
9443f18f6a2SQuinn Dawkins   for (OpResult opResult : linalgOp->getOpResults()) {
9455994201cSMaheshRavishankar     int64_t resultNumber = opResult.getResultNumber();
946a7bfdc23SMahesh Ravishankar     if (resultTypes[resultNumber] != opResult.getType()) {
9475994201cSMaheshRavishankar       SmallVector<ReassociationIndices> reassociation =
9485994201cSMaheshRavishankar           getReassociationForExpansion(
9493f18f6a2SQuinn Dawkins               linalgOp.getMatchingIndexingMap(
9503f18f6a2SQuinn Dawkins                   linalgOp.getDpsInitOperand(resultNumber)),
9515994201cSMaheshRavishankar               expansionInfo);
952b618880eSAlexander Belyaev       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
9533f18f6a2SQuinn Dawkins           linalgOp.getLoc(), opResult.getType(),
9545994201cSMaheshRavishankar           fusedOp->getResult(resultNumber), reassociation));
9555994201cSMaheshRavishankar     } else {
9565994201cSMaheshRavishankar       resultVals.push_back(fusedOp->getResult(resultNumber));
9575994201cSMaheshRavishankar     }
9585994201cSMaheshRavishankar   }
9595994201cSMaheshRavishankar   // Return the replacement values, one per result of the original op.
9605994201cSMaheshRavishankar   return resultVals;
9615994201cSMaheshRavishankar }
9625994201cSMaheshRavishankar
9635994201cSMaheshRavishankar namespace {
9645994201cSMaheshRavishankar
9653f18f6a2SQuinn Dawkins /// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
96632288d37SMahesh Ravishankar /// when the reshape op is collapsing dimensions. The dimensionality of the loop
96732288d37SMahesh Ravishankar /// in the consumer is expanded.
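/// As an illustrative sketch (hypothetical shapes, maps, and values):
///
/// ```mlir
/// #map = affine_map<(d0, d1) -> (d0, d1)>
/// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
///     : tensor<?x4x?xf32> into tensor<?x?xf32>
/// %1 = linalg.generic {indexing_maps = [#map, #map],
///                      iterator_types = ["parallel", "parallel"]}
///     ins(%0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {.. }
/// ```
///
/// becomes a 3-d `linalg.generic` on `%arg` (with `%init` expanded to match),
/// followed by a `tensor.collapse_shape` of its result back to the original
/// 2-d type.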
96832288d37SMahesh Ravishankar class FoldWithProducerReshapeOpByExpansion
9693f18f6a2SQuinn Dawkins     : public OpInterfaceRewritePattern<LinalgOp> {
97032288d37SMahesh Ravishankar public:
9712291705dSMahesh Ravishankar   FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
9722291705dSMahesh Ravishankar                                        ControlFusionFn foldReshapes,
97332288d37SMahesh Ravishankar                                        PatternBenefit benefit = 1)
9743f18f6a2SQuinn Dawkins       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
97532288d37SMahesh Ravishankar         controlFoldingReshapes(std::move(foldReshapes)) {}
9765994201cSMaheshRavishankar
9773f18f6a2SQuinn Dawkins   LogicalResult matchAndRewrite(LinalgOp linalgOp,
9785994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
9793f18f6a2SQuinn Dawkins     for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
98032288d37SMahesh Ravishankar       tensor::CollapseShapeOp reshapeOp =
98132288d37SMahesh Ravishankar           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
9825994201cSMaheshRavishankar       if (!reshapeOp)
9835994201cSMaheshRavishankar         continue;
98432288d37SMahesh Ravishankar       // Fold only if
98532288d37SMahesh Ravishankar       // - all constraints of fusing with the reshape by expansion are met,
98632288d37SMahesh Ravishankar       // - the folding is not blocked by the control function.
9873f18f6a2SQuinn Dawkins       if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
988a7bfdc23SMahesh Ravishankar           (!controlFoldingReshapes(opOperand)))
9895994201cSMaheshRavishankar         continue;
9905994201cSMaheshRavishankar
9910a81ace0SKazu Hirata       std::optional<SmallVector<Value>> replacementValues =
9923f18f6a2SQuinn Dawkins           fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
99332288d37SMahesh Ravishankar       if (!replacementValues)
9945994201cSMaheshRavishankar         return failure();
9953f18f6a2SQuinn Dawkins       rewriter.replaceOp(linalgOp, *replacementValues);
9965994201cSMaheshRavishankar       return success();
9975994201cSMaheshRavishankar     }
9985994201cSMaheshRavishankar     return failure();
9995994201cSMaheshRavishankar   }
100032288d37SMahesh Ravishankar
100132288d37SMahesh Ravishankar private:
10022291705dSMahesh Ravishankar   ControlFusionFn controlFoldingReshapes;
10035994201cSMaheshRavishankar };
10045994201cSMaheshRavishankar
1005c886d66dSMax191 class FoldPadWithProducerReshapeOpByExpansion
1006c886d66dSMax191     : public OpRewritePattern<tensor::PadOp> {
1007c886d66dSMax191 public:
1008c886d66dSMax191   FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
1009c886d66dSMax191                                           ControlFusionFn foldReshapes,
1010c886d66dSMax191                                           PatternBenefit benefit = 1)
1011c886d66dSMax191       : OpRewritePattern<tensor::PadOp>(context, benefit),
1012c886d66dSMax191         controlFoldingReshapes(std::move(foldReshapes)) {}
1013c886d66dSMax191
1014c886d66dSMax191   LogicalResult matchAndRewrite(tensor::PadOp padOp,
1015c886d66dSMax191                                 PatternRewriter &rewriter) const override {
1016c886d66dSMax191     tensor::CollapseShapeOp reshapeOp =
1017c886d66dSMax191         padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
1018c886d66dSMax191     if (!reshapeOp)
1019c886d66dSMax191       return failure();
1020c886d66dSMax191     if (!reshapeOp->hasOneUse())
1021c886d66dSMax191       return failure();
1022c886d66dSMax191
1023c886d66dSMax191     if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
1024c886d66dSMax191       return rewriter.notifyMatchFailure(padOp,
1025c886d66dSMax191                                          "fusion blocked by control function");
1026c886d66dSMax191     }
1027c886d66dSMax191
1028c886d66dSMax191     ArrayRef<int64_t> low = padOp.getStaticLow();
1029c886d66dSMax191     ArrayRef<int64_t> high =
padOp.getStaticHigh(); 1030c886d66dSMax191 SmallVector<ReassociationIndices> reassociations = 1031c886d66dSMax191 reshapeOp.getReassociationIndices(); 1032c886d66dSMax191 1033c886d66dSMax191 for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) { 1034c886d66dSMax191 if (reInd.size() != 1 && (l != 0 || h != 0)) 1035c886d66dSMax191 return failure(); 1036c886d66dSMax191 } 1037c886d66dSMax191 1038c886d66dSMax191 SmallVector<OpFoldResult> newLow, newHigh; 1039c886d66dSMax191 RankedTensorType expandedType = reshapeOp.getSrcType(); 1040c886d66dSMax191 RankedTensorType paddedType = padOp.getResultType(); 1041c886d66dSMax191 SmallVector<int64_t> expandedPaddedShape(expandedType.getShape()); 1042c886d66dSMax191 for (auto [idx, reInd] : llvm::enumerate(reassociations)) { 1043c886d66dSMax191 if (reInd.size() == 1) { 1044c886d66dSMax191 expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx]; 1045c886d66dSMax191 } 1046c886d66dSMax191 for (size_t i = 0; i < reInd.size(); ++i) { 1047c886d66dSMax191 newLow.push_back(padOp.getMixedLowPad()[idx]); 1048c886d66dSMax191 newHigh.push_back(padOp.getMixedHighPad()[idx]); 1049c886d66dSMax191 } 1050c886d66dSMax191 } 1051c886d66dSMax191 1052c886d66dSMax191 Location loc = padOp->getLoc(); 1053c886d66dSMax191 RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); 1054c886d66dSMax191 auto newPadOp = rewriter.create<tensor::PadOp>( 1055c886d66dSMax191 loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, 1056c886d66dSMax191 padOp.getConstantPaddingValue(), padOp.getNofold()); 1057c886d66dSMax191 1058c886d66dSMax191 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( 1059c886d66dSMax191 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations); 1060c886d66dSMax191 1061c886d66dSMax191 return success(); 1062c886d66dSMax191 } 1063c886d66dSMax191 1064c886d66dSMax191 private: 1065c886d66dSMax191 ControlFusionFn controlFoldingReshapes; 1066c886d66dSMax191 }; 1067c886d66dSMax191 10681ad9b266Slorenzo chelini /// Pattern to fold a tensor.expand_shape op with its producer generic op 106932288d37SMahesh Ravishankar /// by expanding the dimensionality of the loop in the producer op. 107032288d37SMahesh Ravishankar struct FoldReshapeWithGenericOpByExpansion 107132288d37SMahesh Ravishankar : public OpRewritePattern<tensor::ExpandShapeOp> { 107232288d37SMahesh Ravishankar 10732291705dSMahesh Ravishankar FoldReshapeWithGenericOpByExpansion(MLIRContext *context, 10742291705dSMahesh Ravishankar ControlFusionFn foldReshapes, 107532288d37SMahesh Ravishankar PatternBenefit benefit = 1) 107632288d37SMahesh Ravishankar : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit), 107732288d37SMahesh Ravishankar controlFoldingReshapes(std::move(foldReshapes)) {} 107832288d37SMahesh Ravishankar 107932288d37SMahesh Ravishankar LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, 108032288d37SMahesh Ravishankar PatternRewriter &rewriter) const override { 108132288d37SMahesh Ravishankar // Fold only if all constraints of fusing with reshape by expansion are met. 
10825550c821STres Popp     auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
1083a7bfdc23SMahesh Ravishankar     if (!producerResult) {
1084a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1085a7bfdc23SMahesh Ravishankar                                          "source not produced by an operation");
1086a7bfdc23SMahesh Ravishankar     }
1087a7bfdc23SMahesh Ravishankar
10883f18f6a2SQuinn Dawkins     auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
1089a7bfdc23SMahesh Ravishankar     if (!producer) {
1090a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1091a7bfdc23SMahesh Ravishankar                                          "producer not a linalg op");
1092a7bfdc23SMahesh Ravishankar     }
1093a7bfdc23SMahesh Ravishankar
1094a7bfdc23SMahesh Ravishankar     if (!isFusableWithReshapeByDimExpansion(
1095a7bfdc23SMahesh Ravishankar             producer,
1096b4db15a9SAlexander Belyaev             producer.getDpsInitOperand(producerResult.getResultNumber()))) {
1097a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(
1098a7bfdc23SMahesh Ravishankar           reshapeOp, "failed preconditions of fusion with producer generic op");
1099a7bfdc23SMahesh Ravishankar     }
1100a7bfdc23SMahesh Ravishankar
11018823e961SMatthias Springer     if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
1102a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1103a7bfdc23SMahesh Ravishankar                                          "fusion blocked by control function");
1104a7bfdc23SMahesh Ravishankar     }
1105a7bfdc23SMahesh Ravishankar
11060a81ace0SKazu Hirata     std::optional<SmallVector<Value>> replacementValues =
11070a81ace0SKazu Hirata         fuseWithReshapeByExpansion(
1108a7bfdc23SMahesh Ravishankar             producer, reshapeOp,
11090a81ace0SKazu Hirata             producer.getDpsInitOperand(producerResult.getResultNumber()),
11100a81ace0SKazu Hirata             rewriter);
1111a7bfdc23SMahesh Ravishankar     if (!replacementValues) {
1112a7bfdc23SMahesh Ravishankar       return rewriter.notifyMatchFailure(reshapeOp,
1113a7bfdc23SMahesh Ravishankar                                          "fusion by expansion failed");
1114a7bfdc23SMahesh Ravishankar     }
1115a7bfdc23SMahesh Ravishankar
1116a7bfdc23SMahesh Ravishankar     // Find the replacement for the reshape op. Since the replacements have the
1117a7bfdc23SMahesh Ravishankar     // same type as the returns of the original generic op, the consumer reshape
1118a7bfdc23SMahesh Ravishankar     // op can be replaced by the source of the collapse_shape op that defines
1119a7bfdc23SMahesh Ravishankar     // the replacement.
11205550c821STres Popp     Value reshapeReplacement =
11215550c821STres Popp         (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
11225550c821STres Popp                                  .getResultNumber()];
1123a7bfdc23SMahesh Ravishankar     if (auto collapseOp =
1124a7bfdc23SMahesh Ravishankar             reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
1125a7bfdc23SMahesh Ravishankar       reshapeReplacement = collapseOp.getSrc();
1126a7bfdc23SMahesh Ravishankar     }
1127a7bfdc23SMahesh Ravishankar     rewriter.replaceOp(reshapeOp, reshapeReplacement);
1128a7bfdc23SMahesh Ravishankar     rewriter.replaceOp(producer, *replacementValues);
112932288d37SMahesh Ravishankar     return success();
113032288d37SMahesh Ravishankar   }
113132288d37SMahesh Ravishankar
113232288d37SMahesh Ravishankar private:
11332291705dSMahesh Ravishankar   ControlFusionFn controlFoldingReshapes;
113432288d37SMahesh Ravishankar };
113532288d37SMahesh Ravishankar } // namespace
113632288d37SMahesh Ravishankar
113732288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
11382c58cde0SMahesh Ravishankar // Methods and patterns to fuse reshape with linalg.generic operations by
11392c58cde0SMahesh Ravishankar // contraction of dimensions.
11402c58cde0SMahesh Ravishankar //===---------------------------------------------------------------------===//
11412c58cde0SMahesh Ravishankar
1142b40e9013SMahesh Ravishankar /// For a given list of indices in the range of the `indexingMap` that are
1143192d9dd7SKazu Hirata /// folded, return the indices of the corresponding domain. The projected
1144192d9dd7SKazu Hirata /// permutation property of the map guarantees that all the elements of the
1145192d9dd7SKazu Hirata /// returned reassociation are distinct.
1146b40e9013SMahesh Ravishankar static ReassociationIndices
11472c58cde0SMahesh Ravishankar getDomainReassociation(AffineMap indexingMap,
1148b40e9013SMahesh Ravishankar                        ReassociationIndicesRef rangeReassociation) {
11492c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
1150b40e9013SMahesh Ravishankar          "expected projected permutation");
11512c58cde0SMahesh Ravishankar
1152b40e9013SMahesh Ravishankar   ReassociationIndices domainReassociation = llvm::to_vector<4>(
1153b40e9013SMahesh Ravishankar       llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
11541609f1c2Slong.chen         return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
1155b40e9013SMahesh Ravishankar       }));
1156b40e9013SMahesh Ravishankar   // The projected permutation semantics ensure that there is no repetition of
1157b40e9013SMahesh Ravishankar   // the domain indices.
11582c58cde0SMahesh Ravishankar   return domainReassociation;
11592c58cde0SMahesh Ravishankar }
11602c58cde0SMahesh Ravishankar
11612c58cde0SMahesh Ravishankar /// For a given `dimSequence`, check if the sequence is preserved in the
11622c58cde0SMahesh Ravishankar /// `indexingMap`. `indexingMap` is expected to be a projected permutation.
11632c58cde0SMahesh Ravishankar /// Non-existence of the sequence returns true as well.
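/// As an illustration, with `dimSequence = [1, 2]` (hypothetical maps):
///
/// ```mlir
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)> // preserved: d1, d2 contiguous
/// affine_map<(d0, d1, d2) -> (d0)>         // preserved: sequence absent
/// affine_map<(d0, d1, d2) -> (d0, d2, d1)> // not preserved: out of order
/// ```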
1164f12639d0SMahesh Ravishankar bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
11652c58cde0SMahesh Ravishankar                                           ReassociationIndicesRef dimSequence) {
11662c58cde0SMahesh Ravishankar   assert(!dimSequence.empty() &&
11672c58cde0SMahesh Ravishankar          "expected non-empty list for dimension sequence");
11682c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
11692c58cde0SMahesh Ravishankar          "expected indexing map to be projected permutation");
11702c58cde0SMahesh Ravishankar
11712c58cde0SMahesh Ravishankar   llvm::SmallDenseSet<unsigned, 4> sequenceElements;
11722c58cde0SMahesh Ravishankar   sequenceElements.insert(dimSequence.begin(), dimSequence.end());
11732c58cde0SMahesh Ravishankar
11742c58cde0SMahesh Ravishankar   unsigned dimSequenceStart = dimSequence[0];
1175a91ade0bSAdrian Kuegel   for (const auto &expr : enumerate(indexingMap.getResults())) {
11761609f1c2Slong.chen     unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
11772c58cde0SMahesh Ravishankar     // 1. Check if this is the start of the sequence.
11782c58cde0SMahesh Ravishankar     if (dimInMapStart == dimSequenceStart) {
11792c58cde0SMahesh Ravishankar       if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
11802c58cde0SMahesh Ravishankar         return false;
11812c58cde0SMahesh Ravishankar       // 1a. Check if the sequence is preserved.
1182a91ade0bSAdrian Kuegel       for (const auto &dimInSequence : enumerate(dimSequence)) {
11832c58cde0SMahesh Ravishankar         unsigned dimInMap =
11841609f1c2Slong.chen             cast<AffineDimExpr>(
11851609f1c2Slong.chen                 indexingMap.getResult(expr.index() + dimInSequence.index()))
11862c58cde0SMahesh Ravishankar                 .getPosition();
11872c58cde0SMahesh Ravishankar         if (dimInMap != dimInSequence.value())
11882c58cde0SMahesh Ravishankar           return false;
11892c58cde0SMahesh Ravishankar       }
11902c58cde0SMahesh Ravishankar       // Found the sequence. Projected permutation
11912c58cde0SMahesh Ravishankar       // enforces that all AffineDimExprs in the result are unique, so no
11922c58cde0SMahesh Ravishankar       // further checks are needed.
11932c58cde0SMahesh Ravishankar       return true;
11942c58cde0SMahesh Ravishankar     }
11952c58cde0SMahesh Ravishankar     // 2. If the position of this expr (an AffineDimExpr) is part of the
11962c58cde0SMahesh Ravishankar     // sequence, return false here. This implies the entire sequence does not
11972c58cde0SMahesh Ravishankar     // exist in the indexing map.
11982c58cde0SMahesh Ravishankar     if (sequenceElements.count(dimInMapStart))
11992c58cde0SMahesh Ravishankar       return false;
12002c58cde0SMahesh Ravishankar   }
12012c58cde0SMahesh Ravishankar   // 3. No element of the sequence was found. Return true.
12022c58cde0SMahesh Ravishankar   return true;
12032c58cde0SMahesh Ravishankar }
12042c58cde0SMahesh Ravishankar
1205f12639d0SMahesh Ravishankar bool mlir::linalg::areDimSequencesPreserved(
1206f12639d0SMahesh Ravishankar     ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
1207f12639d0SMahesh Ravishankar   return llvm::all_of(maps, [&](AffineMap map) {
1208f12639d0SMahesh Ravishankar     return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
1209f12639d0SMahesh Ravishankar       return isDimSequencePreserved(map, dimSequence);
1210f12639d0SMahesh Ravishankar     });
1211f12639d0SMahesh Ravishankar   });
1212f12639d0SMahesh Ravishankar }
1213f12639d0SMahesh Ravishankar
1214b40e9013SMahesh Ravishankar // Return the list of dimensions of the iteration domain that can be
1215b40e9013SMahesh Ravishankar // collapsed to allow for fusion with a producer that is an expand_shape
1216b40e9013SMahesh Ravishankar // operation. If all dimensions created by expansion can be collapsed in the
1217b40e9013SMahesh Ravishankar // iteration space then the reshape becomes redundant.
1218b40e9013SMahesh Ravishankar //
1219b40e9013SMahesh Ravishankar // Example:
1220b40e9013SMahesh Ravishankar //
1221b40e9013SMahesh Ravishankar // ```mlir
1222b40e9013SMahesh Ravishankar // #map = affine_map<(d0, d1) -> (d0, d1)>
1223b40e9013SMahesh Ravishankar // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
122481ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<?x4xf32>
1225b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1226b40e9013SMahesh Ravishankar //        indexing_maps = [#map, #map],
1227b40e9013SMahesh Ravishankar //        iterator_types = ["parallel", "parallel"]}
1228b40e9013SMahesh Ravishankar //        ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
1229b40e9013SMahesh Ravishankar // ```
1230b40e9013SMahesh Ravishankar //
1231b40e9013SMahesh Ravishankar // can be fused by collapsing the dimensions of the iteration space.
1232b40e9013SMahesh Ravishankar //
1233b40e9013SMahesh Ravishankar // ```mlir
1234b40e9013SMahesh Ravishankar // #map = affine_map<(d0) -> (d0)>
123581ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<?xf32>
1236b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1237b40e9013SMahesh Ravishankar //        indexing_maps = [#map, #map],
1238b40e9013SMahesh Ravishankar //        iterator_types = ["parallel"]}
1239b40e9013SMahesh Ravishankar //        ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
1240b40e9013SMahesh Ravishankar // %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
1241b40e9013SMahesh Ravishankar // ```
1242b40e9013SMahesh Ravishankar //
1243b40e9013SMahesh Ravishankar // In the following example,
1244b40e9013SMahesh Ravishankar //
1245b40e9013SMahesh Ravishankar // ```mlir
1246b40e9013SMahesh Ravishankar // #map0 = affine_map<(d0, d1) -> (d0, d1)>
1247b40e9013SMahesh Ravishankar // #map1 = affine_map<(d0, d1) -> (d1, d0)>
1248b40e9013SMahesh Ravishankar // %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
124981ca5aa4SMatthias Springer // %2 = tensor.empty [..] : tensor<4x?xf32>
1250b40e9013SMahesh Ravishankar // %3 = linalg.generic {
1251b40e9013SMahesh Ravishankar //        indexing_maps = [#map0, #map1],
1252b40e9013SMahesh Ravishankar //        iterator_types = ["parallel", "parallel"]}
1253b40e9013SMahesh Ravishankar //        ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
1254b40e9013SMahesh Ravishankar // ```
1255b40e9013SMahesh Ravishankar //
1256b40e9013SMahesh Ravishankar // the reshape cannot be fused with the generic op by collapsing the op
1257b40e9013SMahesh Ravishankar // dimensions, since the indexing maps would have to contain mods and divs
1258b40e9013SMahesh Ravishankar // to preserve the access pattern. When no dimensions of the iteration
1259b40e9013SMahesh Ravishankar // space are collapsible, an empty vector is returned.
1260b40e9013SMahesh Ravishankar static SmallVector<ReassociationIndices>
1261b40e9013SMahesh Ravishankar getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
12622c58cde0SMahesh Ravishankar                                  ArrayRef<ReassociationIndices> reassociation) {
12632c58cde0SMahesh Ravishankar   // Some basic checks for this fusion to be valid.
126408efa230SMax191   if (!genericOp.hasPureTensorSemantics())
1265b40e9013SMahesh Ravishankar     return {};
12662c58cde0SMahesh Ravishankar
1267d2c0572bSJacques Pienaar   if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
12682c58cde0SMahesh Ravishankar         return map.isProjectedPermutation();
12692c58cde0SMahesh Ravishankar       })) {
1270b40e9013SMahesh Ravishankar     return {};
12712c58cde0SMahesh Ravishankar   }
12722c58cde0SMahesh Ravishankar
1273b40e9013SMahesh Ravishankar   // Compute all the loops with the reduction iterator types.
1274c54bc8bdSOleg Shyshkov   SmallVector<unsigned> reductionDims;
1275c54bc8bdSOleg Shyshkov   genericOp.getReductionDims(reductionDims);
12762c58cde0SMahesh Ravishankar
1277b40e9013SMahesh Ravishankar   llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
12781227b8abSOleg Shyshkov   AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
1279c54bc8bdSOleg Shyshkov   auto iteratorTypes = genericOp.getIteratorTypesArray();
1280b40e9013SMahesh Ravishankar   SmallVector<ReassociationIndices> iterationSpaceReassociation;
1281b40e9013SMahesh Ravishankar   for (ReassociationIndicesRef foldedRangeDims : reassociation) {
1282b40e9013SMahesh Ravishankar     assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
1283b40e9013SMahesh Ravishankar
1284b40e9013SMahesh Ravishankar     // Ignore dims that are not folded.
1285b40e9013SMahesh Ravishankar     if (foldedRangeDims.size() == 1)
1286b40e9013SMahesh Ravishankar       continue;
1287b40e9013SMahesh Ravishankar
1288b40e9013SMahesh Ravishankar     ReassociationIndices foldedIterationSpaceDims =
1289b40e9013SMahesh Ravishankar         getDomainReassociation(indexingMap, foldedRangeDims);
1290b40e9013SMahesh Ravishankar
1291b40e9013SMahesh Ravishankar     // Check that the folded iteration dims do not contain already processed
1292b40e9013SMahesh Ravishankar     // dims.
1293b40e9013SMahesh Ravishankar     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1294b40e9013SMahesh Ravishankar           return processedIterationDims.count(dim);
12952c58cde0SMahesh Ravishankar         }))
1296b40e9013SMahesh Ravishankar       continue;
1297b40e9013SMahesh Ravishankar
1298b40e9013SMahesh Ravishankar     // Check that the folded iterator types are either all parallel or all reduction.
1299e6598b05SOleg Shyshkov     utils::IteratorType startIteratorType =
1300e6598b05SOleg Shyshkov         iteratorTypes[foldedIterationSpaceDims[0]];
1301b40e9013SMahesh Ravishankar     if (!isParallelIterator(startIteratorType) &&
1302b40e9013SMahesh Ravishankar         !isReductionIterator(startIteratorType))
1303b40e9013SMahesh Ravishankar       continue;
1304b40e9013SMahesh Ravishankar     if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
1305b40e9013SMahesh Ravishankar           return iteratorTypes[dim] != startIteratorType;
1306b40e9013SMahesh Ravishankar         }))
1307b40e9013SMahesh Ravishankar       continue;
1308b40e9013SMahesh Ravishankar
1309b40e9013SMahesh Ravishankar     // If the folded dimensions correspond to a "reduction" iterator type,
1310b40e9013SMahesh Ravishankar     // the folded dimensions need to be "in-order". Strictly speaking this is
1311b40e9013SMahesh Ravishankar     // not necessary for reductions that are associative and commutative, but
1312b40e9013SMahesh Ravishankar     // we use the stricter definition of reduction for now.
1313b40e9013SMahesh Ravishankar     if (isReductionIterator(startIteratorType)) {
1314b40e9013SMahesh Ravishankar       bool isContiguous = false;
13156120bd47SMehdi Amini       for (const auto &startDim : llvm::enumerate(reductionDims)) {
1316b40e9013SMahesh Ravishankar         // Move window in `reductionDims` to start of the folded iteration dims.
1317b40e9013SMahesh Ravishankar         if (startDim.value() != foldedIterationSpaceDims[0])
1318b40e9013SMahesh Ravishankar           continue;
1319b40e9013SMahesh Ravishankar         // If the sizes don't match, it is trivially not contiguous. This
1320b40e9013SMahesh Ravishankar         // condition should not be hit.
1321b40e9013SMahesh Ravishankar         if (startDim.index() + foldedIterationSpaceDims.size() >
1322b40e9013SMahesh Ravishankar             reductionDims.size())
1323b40e9013SMahesh Ravishankar           break;
1324b40e9013SMahesh Ravishankar         // Check that the contiguity is maintained.
1325b40e9013SMahesh Ravishankar         isContiguous = true;
13266120bd47SMehdi Amini         for (const auto &foldedDim :
13276120bd47SMehdi Amini              llvm::enumerate(foldedIterationSpaceDims)) {
1328b40e9013SMahesh Ravishankar           if (reductionDims[foldedDim.index() + startDim.index()] !=
1329b40e9013SMahesh Ravishankar               foldedDim.value()) {
1330b40e9013SMahesh Ravishankar             isContiguous = false;
1331b40e9013SMahesh Ravishankar             break;
13322c58cde0SMahesh Ravishankar           }
1333b40e9013SMahesh Ravishankar         }
1334b40e9013SMahesh Ravishankar         break;
1335b40e9013SMahesh Ravishankar       }
1336b40e9013SMahesh Ravishankar       if (!isContiguous)
1337b40e9013SMahesh Ravishankar         continue;
1338b40e9013SMahesh Ravishankar     }
1339b40e9013SMahesh Ravishankar
1340b40e9013SMahesh Ravishankar     // Check that the sequence is preserved in all indexing maps.
1341d2c0572bSJacques Pienaar     if (llvm::any_of(genericOp.getIndexingMapsArray(),
1342d2c0572bSJacques Pienaar                      [&](AffineMap indexingMap) {
1343d2c0572bSJacques Pienaar                        return !isDimSequencePreserved(indexingMap,
1344d2c0572bSJacques Pienaar                                                       foldedIterationSpaceDims);
1345b40e9013SMahesh Ravishankar                      }))
1346b40e9013SMahesh Ravishankar       continue;
1347b40e9013SMahesh Ravishankar
1348b40e9013SMahesh Ravishankar     processedIterationDims.insert(foldedIterationSpaceDims.begin(),
1349b40e9013SMahesh Ravishankar                                   foldedIterationSpaceDims.end());
1350b40e9013SMahesh Ravishankar     iterationSpaceReassociation.emplace_back(
1351b40e9013SMahesh Ravishankar         std::move(foldedIterationSpaceDims));
1352b40e9013SMahesh Ravishankar   }
1353b40e9013SMahesh Ravishankar
1354b40e9013SMahesh Ravishankar   return iterationSpaceReassociation;
13552c58cde0SMahesh Ravishankar }
13562c58cde0SMahesh Ravishankar
13572c58cde0SMahesh Ravishankar /// Helper class to carry state while collapsing the `linalg.generic` op.
13582c58cde0SMahesh Ravishankar namespace {
13592c58cde0SMahesh Ravishankar class CollapsingInfo {
13602c58cde0SMahesh Ravishankar public:
1361b40e9013SMahesh Ravishankar   LogicalResult initialize(unsigned origNumLoops,
1362b40e9013SMahesh Ravishankar                            ArrayRef<ReassociationIndices> foldedIterationDims) {
1363b40e9013SMahesh Ravishankar     llvm::SmallDenseSet<int64_t, 4> processedDims;
1364b40e9013SMahesh Ravishankar     // Find all the dims that are folded.
1365b40e9013SMahesh Ravishankar     for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
1366b40e9013SMahesh Ravishankar       if (foldedIterationDim.empty())
1367b40e9013SMahesh Ravishankar         continue;
1368b40e9013SMahesh Ravishankar       // If the folded dims contain dims that are already folded, that is an
1369b40e9013SMahesh Ravishankar       // illegal specification. Repetition within a list is also illegal.
1370b40e9013SMahesh Ravishankar       for (auto dim : foldedIterationDim) {
1371b40e9013SMahesh Ravishankar         if (dim >= origNumLoops)
1372b40e9013SMahesh Ravishankar           return failure();
1373b40e9013SMahesh Ravishankar         if (processedDims.count(dim))
1374b40e9013SMahesh Ravishankar           return failure();
1375b40e9013SMahesh Ravishankar         processedDims.insert(dim);
13762c58cde0SMahesh Ravishankar       }
1377b40e9013SMahesh Ravishankar       collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
1378b40e9013SMahesh Ravishankar                                                    foldedIterationDim.end());
1379b40e9013SMahesh Ravishankar     }
1380b40e9013SMahesh Ravishankar     if (processedDims.size() > origNumLoops)
1381b40e9013SMahesh Ravishankar       return failure();
1382b40e9013SMahesh Ravishankar
1383b40e9013SMahesh Ravishankar     // Add all the preserved dims of the original op as single
1384b40e9013SMahesh Ravishankar     // elements to `collapsedOpToOrigOpIterationDim`.
1385b40e9013SMahesh Ravishankar     for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
1386b40e9013SMahesh Ravishankar       if (processedDims.count(dim))
1387b40e9013SMahesh Ravishankar         continue;
1388b40e9013SMahesh Ravishankar       collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
13892c58cde0SMahesh Ravishankar     }
13902c58cde0SMahesh Ravishankar
1391b40e9013SMahesh Ravishankar     llvm::sort(collapsedOpToOrigOpIterationDim,
1392b40e9013SMahesh Ravishankar                [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
1393b40e9013SMahesh Ravishankar                  return lhs[0] < rhs[0];
1394b40e9013SMahesh Ravishankar                });
1395b40e9013SMahesh Ravishankar     origOpToCollapsedOpIterationDim.resize(origNumLoops);
13966120bd47SMehdi Amini     for (const auto &foldedDims :
13976120bd47SMehdi Amini          llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
13986120bd47SMehdi Amini       for (const auto &dim : enumerate(foldedDims.value()))
1399b40e9013SMahesh Ravishankar         origOpToCollapsedOpIterationDim[dim.value()] =
1400b40e9013SMahesh Ravishankar             std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
1401b40e9013SMahesh Ravishankar     }
1402b40e9013SMahesh Ravishankar     return success();
14032c58cde0SMahesh Ravishankar   }
14042c58cde0SMahesh Ravishankar
1405b40e9013SMahesh Ravishankar   /// Return mapping from collapsed loop domain to original loop domain.
1406b40e9013SMahesh Ravishankar   ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1407b40e9013SMahesh Ravishankar     return collapsedOpToOrigOpIterationDim;
14082c58cde0SMahesh Ravishankar   }
14092c58cde0SMahesh Ravishankar
1410b40e9013SMahesh Ravishankar   /// Return mapping from original loop domain to collapsed loop domain. The
1411b40e9013SMahesh Ravishankar   /// mapping is a pair: the first value is the dimension in the collapsed loop
1412b40e9013SMahesh Ravishankar   /// that the original loop maps to, and the second is the relative position
1413b40e9013SMahesh Ravishankar   /// within that folded list. For example, if the original loop domain is 5-D
1414b40e9013SMahesh Ravishankar   /// and is folded into two groups, i.e.
1415b40e9013SMahesh Ravishankar   ///
1416b40e9013SMahesh Ravishankar   /// ```
1417b40e9013SMahesh Ravishankar   /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1418b40e9013SMahesh Ravishankar   /// ```
1419b40e9013SMahesh Ravishankar   ///
1420b40e9013SMahesh Ravishankar   /// then
1421b40e9013SMahesh Ravishankar   ///
1422b40e9013SMahesh Ravishankar   /// ```
1423b40e9013SMahesh Ravishankar   /// origOpToCollapsedOpMapping[0] = {0, 0};
1424b40e9013SMahesh Ravishankar   /// origOpToCollapsedOpMapping[1] = {0, 1};
1425b40e9013SMahesh Ravishankar   /// origOpToCollapsedOpMapping[2] = {0, 2};
1426b40e9013SMahesh Ravishankar   /// origOpToCollapsedOpMapping[3] = {1, 0};
1427b40e9013SMahesh Ravishankar   /// origOpToCollapsedOpMapping[4] = {1, 1};
1428b40e9013SMahesh Ravishankar   /// ```
1429b40e9013SMahesh Ravishankar   ///
1430b40e9013SMahesh Ravishankar   ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1431b40e9013SMahesh Ravishankar     return origOpToCollapsedOpIterationDim;
14322c58cde0SMahesh Ravishankar   }
14332c58cde0SMahesh Ravishankar
1434b40e9013SMahesh Ravishankar   /// Return the collapsed op iteration domain rank.
1435b40e9013SMahesh Ravishankar   unsigned getCollapsedOpIterationRank() const {
1436b40e9013SMahesh Ravishankar     return collapsedOpToOrigOpIterationDim.size();
14372c58cde0SMahesh Ravishankar   }
14382c58cde0SMahesh Ravishankar
14392c58cde0SMahesh Ravishankar private:
1440b40e9013SMahesh Ravishankar   /// Map from the iteration domain index in collapsed op to the iteration
1441b40e9013SMahesh Ravishankar   /// domain indices in the original op.
1442b40e9013SMahesh Ravishankar   SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
14432c58cde0SMahesh Ravishankar
1444b40e9013SMahesh Ravishankar   /// Map from iteration domain index in the original op to the iteration domain
1445b40e9013SMahesh Ravishankar   /// index in the collapsed op.
1446b40e9013SMahesh Ravishankar   SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
14472c58cde0SMahesh Ravishankar };
14482c58cde0SMahesh Ravishankar } // namespace
14492c58cde0SMahesh Ravishankar
14502c58cde0SMahesh Ravishankar /// Get the iterator types for the collapsed operation given the original
14512c58cde0SMahesh Ravishankar /// iterator types and collapsed dimensions.
1452e6598b05SOleg Shyshkov static SmallVector<utils::IteratorType>
1453e6598b05SOleg Shyshkov getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1454b40e9013SMahesh Ravishankar                             const CollapsingInfo &collapsingInfo) {
1455e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType> collapsedIteratorTypes;
14562c58cde0SMahesh Ravishankar   for (ReassociationIndicesRef foldedIterDims :
1457b40e9013SMahesh Ravishankar        collapsingInfo.getCollapsedOpToOrigOpMapping()) {
14582c58cde0SMahesh Ravishankar     assert(!foldedIterDims.empty() &&
14592c58cde0SMahesh Ravishankar            "reassociation indices expected to have non-empty sets");
14602c58cde0SMahesh Ravishankar     // Just pick the iterator type of the first folded dim. The pre-condition
14612c58cde0SMahesh Ravishankar     // checks are expected to have verified that the iterator types of all
14622c58cde0SMahesh Ravishankar     // folded dimensions are the same.
1463e6598b05SOleg Shyshkov     collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
14642c58cde0SMahesh Ravishankar   }
14652c58cde0SMahesh Ravishankar   return collapsedIteratorTypes;
14662c58cde0SMahesh Ravishankar }
14672c58cde0SMahesh Ravishankar
14682c58cde0SMahesh Ravishankar /// Compute the indexing map in the collapsed op that corresponds to the given
14692c58cde0SMahesh Ravishankar /// `indexingMap` of the original operation.
1470b40e9013SMahesh Ravishankar static AffineMap
1471b40e9013SMahesh Ravishankar getCollapsedOpIndexingMap(AffineMap indexingMap,
1472b40e9013SMahesh Ravishankar                           const CollapsingInfo &collapsingInfo) {
14732c58cde0SMahesh Ravishankar   MLIRContext *context = indexingMap.getContext();
14742c58cde0SMahesh Ravishankar   assert(indexingMap.isProjectedPermutation() &&
14752c58cde0SMahesh Ravishankar          "expected indexing map to be projected permutation");
14762c58cde0SMahesh Ravishankar   SmallVector<AffineExpr> resultExprs;
1477b40e9013SMahesh Ravishankar   auto origOpToCollapsedOpMapping =
1478b40e9013SMahesh Ravishankar       collapsingInfo.getOrigOpToCollapsedOpMapping();
14792c58cde0SMahesh Ravishankar   for (auto expr : indexingMap.getResults()) {
14801609f1c2Slong.chen     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1481b40e9013SMahesh Ravishankar     // If the dim is not the first of its collapsed group, do nothing.
1482b40e9013SMahesh Ravishankar     if (origOpToCollapsedOpMapping[dim].second != 0)
1483b40e9013SMahesh Ravishankar       continue;
1484b40e9013SMahesh Ravishankar     // The next n-dims are guaranteed to be collapsed. So just use the
1485b40e9013SMahesh Ravishankar     // iteration dimension of the collapsed op.
1486b40e9013SMahesh Ravishankar     resultExprs.push_back(
1487b40e9013SMahesh Ravishankar         getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
14882c58cde0SMahesh Ravishankar   }
1489b40e9013SMahesh Ravishankar   return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
14902c58cde0SMahesh Ravishankar                         resultExprs, context);
14912c58cde0SMahesh Ravishankar }
14922c58cde0SMahesh Ravishankar
14932c58cde0SMahesh Ravishankar /// Return the `reassociation` indices to use to collapse the operand when the
14942c58cde0SMahesh Ravishankar /// iteration space of a generic op is collapsed.
14952c58cde0SMahesh Ravishankar static SmallVector<ReassociationIndices>
1496b40e9013SMahesh Ravishankar getOperandReassociation(AffineMap indexingMap,
1497b40e9013SMahesh Ravishankar                         const CollapsingInfo &collapsingInfo) {
14982c58cde0SMahesh Ravishankar   unsigned counter = 0;
14992c58cde0SMahesh Ravishankar   SmallVector<ReassociationIndices> operandReassociation;
1500b40e9013SMahesh Ravishankar   auto origOpToCollapsedOpMapping =
1501b40e9013SMahesh Ravishankar       collapsingInfo.getOrigOpToCollapsedOpMapping();
1502b40e9013SMahesh Ravishankar   auto collapsedOpToOrigOpMapping =
1503b40e9013SMahesh Ravishankar       collapsingInfo.getCollapsedOpToOrigOpMapping();
1504b40e9013SMahesh Ravishankar   while (counter < indexingMap.getNumResults()) {
1505b40e9013SMahesh Ravishankar     unsigned dim =
15061609f1c2Slong.chen         cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1507b40e9013SMahesh Ravishankar     // This is the start of a group of collapsed iteration dimensions that is
1508b40e9013SMahesh Ravishankar     // guaranteed to be preserved in the indexing map. The number of folded
1509b40e9013SMahesh Ravishankar     // dims is obtained from the collapsed op to original op mapping.
15102c58cde0SMahesh Ravishankar     unsigned numFoldedDims =
1511b40e9013SMahesh Ravishankar         collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1512b40e9013SMahesh Ravishankar             .size();
151312cc8e73SGuray Ozen     if (origOpToCollapsedOpMapping[dim].second == 0) {
15142c58cde0SMahesh Ravishankar       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
15152c58cde0SMahesh Ravishankar       operandReassociation.emplace_back(range.begin(), range.end());
15162c58cde0SMahesh Ravishankar     }
151712cc8e73SGuray Ozen     counter += numFoldedDims;
15182c58cde0SMahesh Ravishankar   }
15192c58cde0SMahesh Ravishankar   return operandReassociation;
15202c58cde0SMahesh Ravishankar }
15212c58cde0SMahesh Ravishankar
15222c58cde0SMahesh Ravishankar /// Get the new value to use for a given `OpOperand` in the collapsed operation.
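/// For instance (a sketch with hypothetical static shapes and value names),
/// if loops (d0, d1, d2) are collapsed as [[0, 1], [2]] and the operand is
/// indexed by affine_map<(d0, d1, d2) -> (d0, d1, d2)>, the operand is
/// collapsed with:
///
/// ```mlir
/// %collapsed = tensor.collapse_shape %operand [[0, 1], [2]]
///     : tensor<2x3x5xf32> into tensor<6x5xf32>
/// ```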
15235c3ed392SAviad Cohen static Value getCollapsedOpOperand(Location loc, LinalgOp op,
15242c58cde0SMahesh Ravishankar                                    OpOperand *opOperand,
1525b40e9013SMahesh Ravishankar                                    const CollapsingInfo &collapsingInfo,
15262c58cde0SMahesh Ravishankar                                    OpBuilder &builder) {
15275c3ed392SAviad Cohen   AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
15282c58cde0SMahesh Ravishankar   SmallVector<ReassociationIndices> operandReassociation =
15292c58cde0SMahesh Ravishankar       getOperandReassociation(indexingMap, collapsingInfo);
15302c58cde0SMahesh Ravishankar
15315c3ed392SAviad Cohen   // If the number of entries in the reassociation for the operand is the same
15325c3ed392SAviad Cohen   // as the number of results of the indexing map, then nothing needs to be
15335c3ed392SAviad Cohen   // done for this operand.
15342c58cde0SMahesh Ravishankar   Value operand = opOperand->get();
15352c58cde0SMahesh Ravishankar   if (operandReassociation.size() == indexingMap.getNumResults())
15362c58cde0SMahesh Ravishankar     return operand;
15372c58cde0SMahesh Ravishankar
15382c58cde0SMahesh Ravishankar   // Insert a reshape to collapse the dimensions.
1539d4ae7ee6SAviad Cohen   if (isa<MemRefType>(operand.getType())) {
1540d4ae7ee6SAviad Cohen     return builder
1541d4ae7ee6SAviad Cohen         .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1542d4ae7ee6SAviad Cohen         .getResult();
154346ce993dSMehdi Amini   }
1544d4ae7ee6SAviad Cohen   return builder
1545d4ae7ee6SAviad Cohen       .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1546d4ae7ee6SAviad Cohen       .getResult();
1547d4ae7ee6SAviad Cohen }
15482c58cde0SMahesh Ravishankar
15492c58cde0SMahesh Ravishankar /// Rewrite the `linalg.index` operations in the original generic op to their
15502c58cde0SMahesh Ravishankar /// values in the collapsed operation.
15512c58cde0SMahesh Ravishankar void generateCollapsedIndexingRegion(Location loc, Block *block,
1552b40e9013SMahesh Ravishankar                                      const CollapsingInfo &collapsingInfo,
15532c58cde0SMahesh Ravishankar                                      ValueRange loopRange,
1554135977c9SGuray Ozen                                      RewriterBase &rewriter) {
15552c58cde0SMahesh Ravishankar   OpBuilder::InsertionGuard g(rewriter);
15562c58cde0SMahesh Ravishankar   rewriter.setInsertionPointToStart(block);
15572c58cde0SMahesh Ravishankar
15582c58cde0SMahesh Ravishankar   // Collect all the original index ops.
15592c58cde0SMahesh Ravishankar   auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
15602c58cde0SMahesh Ravishankar
15612c58cde0SMahesh Ravishankar   // For each folded dimension list resolve the original induction variable
15622c58cde0SMahesh Ravishankar   // values in terms of the folded dimension induction variable.
15632c58cde0SMahesh Ravishankar   //   i_{folded} = (i0 * d1 + i1) * d2 + i2.
15642c58cde0SMahesh Ravishankar // can be inverted to 15652c58cde0SMahesh Ravishankar // i2 = i_{folded} % d2 15662c58cde0SMahesh Ravishankar // i1 = (i_{folded} / d2) % d1 15672c58cde0SMahesh Ravishankar // i0 = i_{folded} / (d1 * d2) 15682c58cde0SMahesh Ravishankar llvm::DenseMap<unsigned, Value> indexReplacementVals; 1569a0a76804SJakub Kuderski for (auto foldedDims : 1570b40e9013SMahesh Ravishankar enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { 15712c58cde0SMahesh Ravishankar ReassociationIndicesRef foldedDimsRef(foldedDims.value()); 15722c58cde0SMahesh Ravishankar Value newIndexVal = 15732c58cde0SMahesh Ravishankar rewriter.create<linalg::IndexOp>(loc, foldedDims.index()); 15742c58cde0SMahesh Ravishankar for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { 15752c58cde0SMahesh Ravishankar indexReplacementVals[dim] = 1576*1f5335c1SMaheshRavishankar rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]); 15772c58cde0SMahesh Ravishankar newIndexVal = 1578*1f5335c1SMaheshRavishankar rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]); 15792c58cde0SMahesh Ravishankar } 15802c58cde0SMahesh Ravishankar indexReplacementVals[foldedDims.value().front()] = newIndexVal; 15812c58cde0SMahesh Ravishankar } 15822c58cde0SMahesh Ravishankar 15832c58cde0SMahesh Ravishankar for (auto indexOp : indexOps) { 1584d3b3f765SJacques Pienaar auto dim = indexOp.getDim(); 15852c58cde0SMahesh Ravishankar rewriter.replaceOp(indexOp, indexReplacementVals[dim]); 15862c58cde0SMahesh Ravishankar } 15872c58cde0SMahesh Ravishankar } 15882c58cde0SMahesh Ravishankar 1589b6f4dd9eSsrcarroll void collapseOperandsAndResults(LinalgOp op, 15905c3ed392SAviad Cohen const CollapsingInfo &collapsingInfo, 1591b6f4dd9eSsrcarroll RewriterBase &rewriter, 1592b6f4dd9eSsrcarroll SmallVectorImpl<Value> &inputOperands, 1593b6f4dd9eSsrcarroll SmallVectorImpl<Value> &outputOperands, 1594b6f4dd9eSsrcarroll SmallVectorImpl<Type> &resultTypes) { 15955c3ed392SAviad Cohen Location loc = op->getLoc(); 1596b6f4dd9eSsrcarroll inputOperands = 15975c3ed392SAviad Cohen llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { 15985c3ed392SAviad Cohen return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo, 15995c3ed392SAviad Cohen rewriter); 16005c3ed392SAviad Cohen }); 16015c3ed392SAviad Cohen 16025c3ed392SAviad Cohen // Get the output operands and result types. 16035c3ed392SAviad Cohen resultTypes.reserve(op.getNumDpsInits()); 16045c3ed392SAviad Cohen outputOperands.reserve(op.getNumDpsInits()); 16055c3ed392SAviad Cohen for (OpOperand &output : op.getDpsInitsMutable()) { 16065c3ed392SAviad Cohen Value newOutput = 16075c3ed392SAviad Cohen getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter); 16085c3ed392SAviad Cohen outputOperands.push_back(newOutput); 16095c3ed392SAviad Cohen // If the op has "buffer semantics", then the init operands are ranked 16105c3ed392SAviad Cohen // memrefs and the op has no results. 
16110a8e3dd4SMatthias Springer     if (!op.hasPureBufferSemantics())
16125c3ed392SAviad Cohen       resultTypes.push_back(newOutput.getType());
16135c3ed392SAviad Cohen   }
16145c3ed392SAviad Cohen }
16155c3ed392SAviad Cohen
1616b6f4dd9eSsrcarroll /// Clone a `LinalgOp` to a collapsed version of the same operation.
1617b6f4dd9eSsrcarroll template <typename OpTy>
1618b6f4dd9eSsrcarroll OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1619b6f4dd9eSsrcarroll                         const CollapsingInfo &collapsingInfo) {
1620b6f4dd9eSsrcarroll   return nullptr;
1621b6f4dd9eSsrcarroll }
16225c3ed392SAviad Cohen
1623b6f4dd9eSsrcarroll /// Collapse any `LinalgOp` that does not require any specialization (such as
1624b6f4dd9eSsrcarroll /// indexing_maps, iterator_types, etc.).
1625b6f4dd9eSsrcarroll template <>
1626b6f4dd9eSsrcarroll LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
1627b6f4dd9eSsrcarroll                                       const CollapsingInfo &collapsingInfo) {
1628b6f4dd9eSsrcarroll   SmallVector<Value> inputOperands, outputOperands;
1629b6f4dd9eSsrcarroll   SmallVector<Type> resultTypes;
1630b6f4dd9eSsrcarroll   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1631b6f4dd9eSsrcarroll                              outputOperands, resultTypes);
1632ccc02563SAviad Cohen
1633ccc02563SAviad Cohen   return clone(
1634b6f4dd9eSsrcarroll       rewriter, origOp, resultTypes,
1635ccc02563SAviad Cohen       llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
1636b6f4dd9eSsrcarroll }
1637b6f4dd9eSsrcarroll
1638b6f4dd9eSsrcarroll /// Collapse a `GenericOp`.
1639b6f4dd9eSsrcarroll template <>
1640b6f4dd9eSsrcarroll GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
1641b6f4dd9eSsrcarroll                                         GenericOp origOp,
1642b6f4dd9eSsrcarroll                                         const CollapsingInfo &collapsingInfo) {
1643b6f4dd9eSsrcarroll   SmallVector<Value> inputOperands, outputOperands;
1644b6f4dd9eSsrcarroll   SmallVector<Type> resultTypes;
1645b6f4dd9eSsrcarroll   collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
1646b6f4dd9eSsrcarroll                              outputOperands, resultTypes);
1647b6f4dd9eSsrcarroll   SmallVector<AffineMap> indexingMaps(
1648b6f4dd9eSsrcarroll       llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
16495c3ed392SAviad Cohen         return getCollapsedOpIndexingMap(map, collapsingInfo);
1650b6f4dd9eSsrcarroll       }));
16515c3ed392SAviad Cohen
1652b6f4dd9eSsrcarroll   SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
1653b6f4dd9eSsrcarroll       origOp.getIteratorTypesArray(), collapsingInfo));
1654b6f4dd9eSsrcarroll
1655b6f4dd9eSsrcarroll   GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
1656b6f4dd9eSsrcarroll       origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
16575c3ed392SAviad Cohen       iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
1658b6f4dd9eSsrcarroll   Block *origOpBlock = &origOp->getRegion(0).front();
16595c3ed392SAviad Cohen   Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
16605c3ed392SAviad Cohen   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
16615c3ed392SAviad Cohen                        collapsedOpBlock->getArguments());
16625c3ed392SAviad Cohen   return collapsedOp;
16635c3ed392SAviad Cohen }
16645c3ed392SAviad Cohen
1665b6f4dd9eSsrcarroll LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
16665c3ed392SAviad Cohen                            RewriterBase &rewriter) {
1667b6f4dd9eSsrcarroll   if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
1668b6f4dd9eSsrcarroll     return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
1669b6f4dd9eSsrcarroll   } else {
1670b6f4dd9eSsrcarroll return cloneToCollapsedOp(rewriter, op, collapsingInfo); 1671b6f4dd9eSsrcarroll } 1672b6f4dd9eSsrcarroll } 16735c3ed392SAviad Cohen 1674b6f4dd9eSsrcarroll /// Implementation of fusion with reshape operation by collapsing dimensions. 1675b6f4dd9eSsrcarroll FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims( 1676b6f4dd9eSsrcarroll LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims, 1677b6f4dd9eSsrcarroll RewriterBase &rewriter) { 1678b40e9013SMahesh Ravishankar // Bail on trivial no-op cases. 16795c3ed392SAviad Cohen if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || 1680b40e9013SMahesh Ravishankar llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { 1681b40e9013SMahesh Ravishankar return foldedDims.size() <= 1; 1682b40e9013SMahesh Ravishankar })) 1683b40e9013SMahesh Ravishankar return failure(); 16842c58cde0SMahesh Ravishankar 16850a8e3dd4SMatthias Springer bool hasPureBufferSemantics = op.hasPureBufferSemantics(); 16860a8e3dd4SMatthias Springer if (hasPureBufferSemantics && 16875c3ed392SAviad Cohen !llvm::all_of(op->getOperands(), [&](Value operand) -> bool { 1688d4ae7ee6SAviad Cohen MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType()); 1689d4ae7ee6SAviad Cohen if (!memRefToCollapse) 1690d4ae7ee6SAviad Cohen return true; 1691d4ae7ee6SAviad Cohen 1692d4ae7ee6SAviad Cohen return memref::CollapseShapeOp::isGuaranteedCollapsible( 1693d4ae7ee6SAviad Cohen memRefToCollapse, foldedIterationDims); 1694d4ae7ee6SAviad Cohen })) 16955c3ed392SAviad Cohen return rewriter.notifyMatchFailure(op, 1696d4ae7ee6SAviad Cohen "memref is not guaranteed collapsible"); 1697d4ae7ee6SAviad Cohen 1698b40e9013SMahesh Ravishankar CollapsingInfo collapsingInfo; 16995c3ed392SAviad Cohen if (failed( 17005c3ed392SAviad Cohen collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) { 1701b40e9013SMahesh Ravishankar return rewriter.notifyMatchFailure( 17025c3ed392SAviad Cohen op, "illegal to collapse specified dimensions"); 1703b40e9013SMahesh Ravishankar } 17042c58cde0SMahesh Ravishankar 170570e99f38SAlex Zinenko // Bail on non-canonical ranges. 1706b6f4dd9eSsrcarroll SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc()); 170770e99f38SAlex Zinenko auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { 170868f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) 17095550c821STres Popp return cast<IntegerAttr>(attr).getInt() == value; 171070e99f38SAlex Zinenko llvm::APInt actual; 17114f279a57SKazu Hirata return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) && 171270e99f38SAlex Zinenko actual.getSExtValue() == value; 171370e99f38SAlex Zinenko }; 171470e99f38SAlex Zinenko if (!llvm::all_of(loopRanges, [&](Range range) { 171570e99f38SAlex Zinenko return opFoldIsConstantValue(range.offset, 0) && 171670e99f38SAlex Zinenko opFoldIsConstantValue(range.stride, 1); 171770e99f38SAlex Zinenko })) { 171870e99f38SAlex Zinenko return rewriter.notifyMatchFailure( 17195c3ed392SAviad Cohen op, "expected all loop ranges to have zero start and unit stride"); 172070e99f38SAlex Zinenko } 172170e99f38SAlex Zinenko 1722b6f4dd9eSsrcarroll LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter); 17232c58cde0SMahesh Ravishankar 17245c3ed392SAviad Cohen Location loc = op->getLoc(); 17255c3ed392SAviad Cohen if (collapsedOp.hasIndexSemantics()) { 17262c58cde0SMahesh Ravishankar // Collect the loop range of the generic op. 
17272c58cde0SMahesh Ravishankar     OpBuilder::InsertionGuard g(rewriter);
17285c3ed392SAviad Cohen     rewriter.setInsertionPoint(collapsedOp);
172970e99f38SAlex Zinenko     SmallVector<Value> loopBound =
17305c3ed392SAviad Cohen         llvm::map_to_vector(loopRanges, [&](Range range) {
17314bf84e43SAlexander Belyaev           return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
17325c3ed392SAviad Cohen         });
17335c3ed392SAviad Cohen     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
17342c58cde0SMahesh Ravishankar                                     collapsingInfo, loopBound, rewriter);
17352c58cde0SMahesh Ravishankar   }
17362c58cde0SMahesh Ravishankar 
17372c58cde0SMahesh Ravishankar   // Insert expanding reshape for the result to get back the original result
17382c58cde0SMahesh Ravishankar   // type.
17392c58cde0SMahesh Ravishankar   SmallVector<Value> results;
17405c3ed392SAviad Cohen   for (const auto &originalResult : llvm::enumerate(op->getResults())) {
17415c3ed392SAviad Cohen     Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
17422c58cde0SMahesh Ravishankar     auto originalResultType =
17435550c821STres Popp         cast<ShapedType>(originalResult.value().getType());
17445550c821STres Popp     auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
17452c58cde0SMahesh Ravishankar     if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
17462c58cde0SMahesh Ravishankar       AffineMap indexingMap =
17475c3ed392SAviad Cohen           op.getIndexingMapMatchingResult(originalResult.value());
17482c58cde0SMahesh Ravishankar       SmallVector<ReassociationIndices> reassociation =
17492c58cde0SMahesh Ravishankar           getOperandReassociation(indexingMap, collapsingInfo);
175097069a86SGaurav Shukla       Value result;
1751d4ae7ee6SAviad Cohen       if (isa<MemRefType>(collapsedOpResult.getType())) {
175297069a86SGaurav Shukla         MemRefType expandShapeResultType = MemRefType::get(
175397069a86SGaurav Shukla             originalResultType.getShape(), originalResultType.getElementType());
175497069a86SGaurav Shukla         result = rewriter.create<memref::ExpandShapeOp>(
175597069a86SGaurav Shukla             loc, expandShapeResultType, collapsedOpResult, reassociation);
17562c58cde0SMahesh Ravishankar       } else {
175797069a86SGaurav Shukla         result = rewriter.create<tensor::ExpandShapeOp>(
17588c0341dfSMehdi Amini             loc, originalResultType, collapsedOpResult, reassociation);
17598c0341dfSMehdi Amini       }
176097069a86SGaurav Shukla       results.push_back(result);
17618c0341dfSMehdi Amini     } else {
17622c58cde0SMahesh Ravishankar       results.push_back(collapsedOpResult);
17632c58cde0SMahesh Ravishankar     }
17642c58cde0SMahesh Ravishankar   }
1765b6f4dd9eSsrcarroll   return CollapseResult{results, collapsedOp};
17662c58cde0SMahesh Ravishankar }
17672c58cde0SMahesh Ravishankar 
17682c58cde0SMahesh Ravishankar namespace {
17692c58cde0SMahesh Ravishankar 
17702c58cde0SMahesh Ravishankar /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
17712c58cde0SMahesh Ravishankar /// collapsing dimensions of the loop.
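///
/// A hedged sketch of the rewrite (shapes, maps, and SSA names are
/// illustrative only, not taken from the implementation):
///
///   %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [3, 4]
///       : tensor<12xf32> into tensor<3x4xf32>
///   %1 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                           affine_map<(d0, d1) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel"]}
///        ins(%0 : tensor<3x4xf32>) outs(%init : tensor<3x4xf32>) { ... }
///
/// becomes, once the two loops are collapsed into one and the resulting
/// collapse_shape/expand_shape pair on the input cancels out,
///
///   %1 = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> (d0)>],
///          iterator_types = ["parallel"]}
///        ins(%arg0 : tensor<12xf32>)
///        outs(%collapsed_init : tensor<12xf32>) { ... }
///   %2 = tensor.expand_shape %1 [[0, 1]] output_shape [3, 4]
///       : tensor<12xf32> into tensor<3x4xf32>
///
/// where %collapsed_init is %init collapsed through tensor.collapse_shape.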
17722c58cde0SMahesh Ravishankar class FoldWithProducerReshapeOpByCollapsing 17732c58cde0SMahesh Ravishankar : public OpRewritePattern<GenericOp> { 17742c58cde0SMahesh Ravishankar public: 17752291705dSMahesh Ravishankar FoldWithProducerReshapeOpByCollapsing(MLIRContext *context, 17762291705dSMahesh Ravishankar ControlFusionFn foldReshapes, 17772c58cde0SMahesh Ravishankar PatternBenefit benefit = 1) 17782c58cde0SMahesh Ravishankar : OpRewritePattern<GenericOp>(context, benefit), 17792c58cde0SMahesh Ravishankar controlFoldingReshapes(std::move(foldReshapes)) {} 17802c58cde0SMahesh Ravishankar 17812c58cde0SMahesh Ravishankar LogicalResult matchAndRewrite(GenericOp genericOp, 17822c58cde0SMahesh Ravishankar PatternRewriter &rewriter) const override { 1783a7cccb9cSAlexander Belyaev for (OpOperand &opOperand : genericOp->getOpOperands()) { 17842c58cde0SMahesh Ravishankar tensor::ExpandShapeOp reshapeOp = 1785a7cccb9cSAlexander Belyaev opOperand.get().getDefiningOp<tensor::ExpandShapeOp>(); 17862c58cde0SMahesh Ravishankar if (!reshapeOp) 17872c58cde0SMahesh Ravishankar continue; 17882c58cde0SMahesh Ravishankar 1789b40e9013SMahesh Ravishankar SmallVector<ReassociationIndices> collapsableIterationDims = 1790a7cccb9cSAlexander Belyaev getCollapsableIterationSpaceDims(genericOp, &opOperand, 1791b40e9013SMahesh Ravishankar reshapeOp.getReassociationIndices()); 1792b40e9013SMahesh Ravishankar if (collapsableIterationDims.empty() || 1793a7cccb9cSAlexander Belyaev !controlFoldingReshapes(&opOperand)) { 17942c58cde0SMahesh Ravishankar continue; 17952c58cde0SMahesh Ravishankar } 17962c58cde0SMahesh Ravishankar 1797b6f4dd9eSsrcarroll std::optional<CollapseResult> collapseResult = collapseOpIterationDims( 17985c3ed392SAviad Cohen genericOp, collapsableIterationDims, rewriter); 1799b6f4dd9eSsrcarroll if (!collapseResult) { 18002c58cde0SMahesh Ravishankar return rewriter.notifyMatchFailure( 18012c58cde0SMahesh Ravishankar genericOp, "failed to do the fusion by collapsing transformation"); 18022c58cde0SMahesh Ravishankar } 18032c58cde0SMahesh Ravishankar 1804b6f4dd9eSsrcarroll rewriter.replaceOp(genericOp, collapseResult->results); 18052c58cde0SMahesh Ravishankar return success(); 18062c58cde0SMahesh Ravishankar } 18072c58cde0SMahesh Ravishankar return failure(); 18082c58cde0SMahesh Ravishankar } 18092c58cde0SMahesh Ravishankar 18102c58cde0SMahesh Ravishankar private: 18112291705dSMahesh Ravishankar ControlFusionFn controlFoldingReshapes; 18122c58cde0SMahesh Ravishankar }; 181383c65fbcSThomas Raoux 1814c886d66dSMax191 class FoldPadWithProducerReshapeOpByCollapsing 1815c886d66dSMax191 : public OpRewritePattern<tensor::PadOp> { 1816c886d66dSMax191 public: 1817c886d66dSMax191 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context, 1818c886d66dSMax191 ControlFusionFn foldReshapes, 1819c886d66dSMax191 PatternBenefit benefit = 1) 1820c886d66dSMax191 : OpRewritePattern<tensor::PadOp>(context, benefit), 1821c886d66dSMax191 controlFoldingReshapes(std::move(foldReshapes)) {} 1822c886d66dSMax191 1823c886d66dSMax191 LogicalResult matchAndRewrite(tensor::PadOp padOp, 1824c886d66dSMax191 PatternRewriter &rewriter) const override { 1825c886d66dSMax191 tensor::ExpandShapeOp reshapeOp = 1826c886d66dSMax191 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); 1827c886d66dSMax191 if (!reshapeOp) 1828c886d66dSMax191 return failure(); 1829c886d66dSMax191 if (!reshapeOp->hasOneUse()) 1830c886d66dSMax191 return failure(); 1831c886d66dSMax191 1832c886d66dSMax191 if (!controlFoldingReshapes(&padOp.getSourceMutable())) { 
1833c886d66dSMax191 return rewriter.notifyMatchFailure(padOp, 1834c886d66dSMax191 "fusion blocked by control function"); 1835c886d66dSMax191 } 1836c886d66dSMax191 1837c886d66dSMax191 ArrayRef<int64_t> low = padOp.getStaticLow(); 1838c886d66dSMax191 ArrayRef<int64_t> high = padOp.getStaticHigh(); 1839c886d66dSMax191 SmallVector<ReassociationIndices> reassociations = 1840c886d66dSMax191 reshapeOp.getReassociationIndices(); 1841c886d66dSMax191 1842c886d66dSMax191 for (auto reInd : reassociations) { 1843c886d66dSMax191 if (reInd.size() == 1) 1844c886d66dSMax191 continue; 1845c886d66dSMax191 if (llvm::any_of(reInd, [&](int64_t ind) { 1846c886d66dSMax191 return low[ind] != 0 || high[ind] != 0; 1847c886d66dSMax191 })) { 1848c886d66dSMax191 return failure(); 1849c886d66dSMax191 } 1850c886d66dSMax191 } 1851c886d66dSMax191 1852c886d66dSMax191 SmallVector<OpFoldResult> newLow, newHigh; 1853c886d66dSMax191 RankedTensorType collapsedType = reshapeOp.getSrcType(); 1854c886d66dSMax191 RankedTensorType paddedType = padOp.getResultType(); 1855c886d66dSMax191 SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape()); 1856c886d66dSMax191 SmallVector<OpFoldResult> expandedPaddedSizes( 1857c886d66dSMax191 getMixedValues(reshapeOp.getStaticOutputShape(), 1858c886d66dSMax191 reshapeOp.getOutputShape(), rewriter)); 1859c886d66dSMax191 AffineExpr d0, d1, d2; 1860c886d66dSMax191 bindDims(rewriter.getContext(), d0, d1, d2); 1861c886d66dSMax191 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); 1862c886d66dSMax191 Location loc = reshapeOp->getLoc(); 1863c886d66dSMax191 for (auto [idx, reInd] : llvm::enumerate(reassociations)) { 1864c886d66dSMax191 OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; 1865c886d66dSMax191 OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; 1866c886d66dSMax191 if (reInd.size() == 1) { 1867c886d66dSMax191 collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]]; 1868c886d66dSMax191 OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( 1869c886d66dSMax191 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); 1870c886d66dSMax191 expandedPaddedSizes[reInd[0]] = paddedSize; 1871c886d66dSMax191 } 1872c886d66dSMax191 newLow.push_back(l); 1873c886d66dSMax191 newHigh.push_back(h); 1874c886d66dSMax191 } 1875c886d66dSMax191 1876c886d66dSMax191 RankedTensorType collapsedPaddedType = 1877c886d66dSMax191 paddedType.clone(collapsedPaddedShape); 1878c886d66dSMax191 auto newPadOp = rewriter.create<tensor::PadOp>( 1879c886d66dSMax191 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, 1880c886d66dSMax191 padOp.getConstantPaddingValue(), padOp.getNofold()); 1881c886d66dSMax191 1882c886d66dSMax191 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( 1883c886d66dSMax191 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations, 1884c886d66dSMax191 expandedPaddedSizes); 1885c886d66dSMax191 1886c886d66dSMax191 return success(); 1887c886d66dSMax191 } 1888c886d66dSMax191 1889c886d66dSMax191 private: 1890c886d66dSMax191 ControlFusionFn controlFoldingReshapes; 1891c886d66dSMax191 }; 1892c886d66dSMax191 189383c65fbcSThomas Raoux /// Pattern to collapse dimensions. 
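///
/// As a hedged example (the control function and shapes are hypothetical):
/// if the control function returns {{0, 1}} for a 2-d elementwise
/// `linalg.generic` over `tensor<4x8xf32>`, the op is rewritten into a 1-d
/// generic over `tensor<32xf32>`, with its operands collapsed through
/// `tensor.collapse_shape` (or `memref.collapse_shape` for buffer semantics)
/// and its results expanded back to the original type through
/// `tensor.expand_shape`.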
18945c3ed392SAviad Cohen template <typename LinalgType> 18955c3ed392SAviad Cohen class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { 189683c65fbcSThomas Raoux public: 189783c65fbcSThomas Raoux CollapseLinalgDimensions(MLIRContext *context, 189883c65fbcSThomas Raoux GetCollapsableDimensionsFn collapseDimensions, 189983c65fbcSThomas Raoux PatternBenefit benefit = 1) 19005c3ed392SAviad Cohen : OpRewritePattern<LinalgType>(context, benefit), 190183c65fbcSThomas Raoux controlCollapseDimension(std::move(collapseDimensions)) {} 190283c65fbcSThomas Raoux 19035c3ed392SAviad Cohen LogicalResult matchAndRewrite(LinalgType op, 190483c65fbcSThomas Raoux PatternRewriter &rewriter) const override { 190583c65fbcSThomas Raoux SmallVector<ReassociationIndices> collapsableIterationDims = 19065c3ed392SAviad Cohen controlCollapseDimension(op); 190783c65fbcSThomas Raoux if (collapsableIterationDims.empty()) 190883c65fbcSThomas Raoux return failure(); 190983c65fbcSThomas Raoux 1910f12639d0SMahesh Ravishankar // Check if the specified list of dimensions to collapse is a valid list. 19115c3ed392SAviad Cohen if (!areDimSequencesPreserved(op.getIndexingMapsArray(), 1912f12639d0SMahesh Ravishankar collapsableIterationDims)) { 1913f12639d0SMahesh Ravishankar return rewriter.notifyMatchFailure( 19145c3ed392SAviad Cohen op, "specified dimensions cannot be collapsed"); 1915f12639d0SMahesh Ravishankar } 1916f12639d0SMahesh Ravishankar 1917b6f4dd9eSsrcarroll std::optional<CollapseResult> collapseResult = 1918b6f4dd9eSsrcarroll collapseOpIterationDims(op, collapsableIterationDims, rewriter); 1919b6f4dd9eSsrcarroll if (!collapseResult) { 19205c3ed392SAviad Cohen return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); 192183c65fbcSThomas Raoux } 1922b6f4dd9eSsrcarroll rewriter.replaceOp(op, collapseResult->results); 192383c65fbcSThomas Raoux return success(); 192483c65fbcSThomas Raoux } 192583c65fbcSThomas Raoux 192683c65fbcSThomas Raoux private: 192783c65fbcSThomas Raoux GetCollapsableDimensionsFn controlCollapseDimension; 192883c65fbcSThomas Raoux }; 192983c65fbcSThomas Raoux 19302c58cde0SMahesh Ravishankar } // namespace 19312c58cde0SMahesh Ravishankar 19322c58cde0SMahesh Ravishankar //===---------------------------------------------------------------------===// 193332288d37SMahesh Ravishankar // Methods and patterns that fuse constants with linalg.generic operations. 193432288d37SMahesh Ravishankar //===---------------------------------------------------------------------===// 19355994201cSMaheshRavishankar 193632288d37SMahesh Ravishankar namespace { 1937a40a08edSMaheshRavishankar /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1938a40a08edSMaheshRavishankar /// handle cases where the constant is not single-valued. 
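///
/// As an illustration (values and shapes hypothetical), an input operand
/// defined by
///
///   %cst = arith.constant dense<1.0> : tensor<4xf32>
///
/// is removed from the generic op together with its indexing map, and the
/// corresponding block argument is replaced by the scalar
///
///   %c = arith.constant 1.0 : f32
///
/// materialized with the original constant's location.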
19394cd7ff67SLei Zhang class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
19405994201cSMaheshRavishankar public:
19412291705dSMahesh Ravishankar   FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
19422291705dSMahesh Ravishankar       : OpRewritePattern<GenericOp>(context, benefit) {}
19435994201cSMaheshRavishankar 
19445994201cSMaheshRavishankar   LogicalResult matchAndRewrite(GenericOp genericOp,
19455994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
19460a8e3dd4SMatthias Springer     if (!genericOp.hasPureTensorSemantics())
19475994201cSMaheshRavishankar       return failure();
1948b4db15a9SAlexander Belyaev     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
19495994201cSMaheshRavishankar       Operation *def = opOperand->get().getDefiningOp();
1950e1795322SJeff Niu       TypedAttr constantAttr;
1951a40a08edSMaheshRavishankar       auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1952a40a08edSMaheshRavishankar         {
1953a40a08edSMaheshRavishankar           DenseElementsAttr splatAttr;
1954a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1955a40a08edSMaheshRavishankar               splatAttr.isSplat() &&
1956a40a08edSMaheshRavishankar               splatAttr.getType().getElementType().isIntOrFloat()) {
1957e1795322SJeff Niu             constantAttr = splatAttr.getSplatValue<TypedAttr>();
1958a40a08edSMaheshRavishankar             return true;
1959a40a08edSMaheshRavishankar           }
1960a40a08edSMaheshRavishankar         }
1961a40a08edSMaheshRavishankar         {
1962a40a08edSMaheshRavishankar           IntegerAttr intAttr;
1963a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1964a40a08edSMaheshRavishankar             constantAttr = intAttr;
1965a40a08edSMaheshRavishankar             return true;
1966a40a08edSMaheshRavishankar           }
1967a40a08edSMaheshRavishankar         }
1968a40a08edSMaheshRavishankar         {
1969a40a08edSMaheshRavishankar           FloatAttr floatAttr;
1970a40a08edSMaheshRavishankar           if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1971a40a08edSMaheshRavishankar             constantAttr = floatAttr;
1972a40a08edSMaheshRavishankar             return true;
1973a40a08edSMaheshRavishankar           }
1974a40a08edSMaheshRavishankar         }
1975a40a08edSMaheshRavishankar         return false;
1976a40a08edSMaheshRavishankar       };
1977a40a08edSMaheshRavishankar 
19785550c821STres Popp       auto resultValue = dyn_cast<OpResult>(opOperand->get());
19792291705dSMahesh Ravishankar       if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
19805994201cSMaheshRavishankar         continue;
19815994201cSMaheshRavishankar 
19825994201cSMaheshRavishankar       // The operands and indexing_maps of the fused operation are the same as
19835994201cSMaheshRavishankar       // those of the original generic operation, with the constant operand
19845994201cSMaheshRavishankar       // and its indexing map dropped.
19855994201cSMaheshRavishankar       SmallVector<AffineMap> fusedIndexMaps;
19865994201cSMaheshRavishankar       SmallVector<Value> fusedOperands;
1987b983783dSGeoffrey Martin-Noble       SmallVector<Location> fusedLocs{genericOp.getLoc()};
1988a7cccb9cSAlexander Belyaev       fusedIndexMaps.reserve(genericOp->getNumOperands());
1989b4db15a9SAlexander Belyaev       fusedOperands.reserve(genericOp.getNumDpsInputs());
1990b4db15a9SAlexander Belyaev       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1991b4db15a9SAlexander Belyaev       for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
19925994201cSMaheshRavishankar         if (inputOperand == opOperand)
19935994201cSMaheshRavishankar           continue;
1994b983783dSGeoffrey Martin-Noble         Value inputValue = inputOperand->get();
19951227b8abSOleg Shyshkov         fusedIndexMaps.push_back(
19961227b8abSOleg Shyshkov             genericOp.getMatchingIndexingMap(inputOperand));
1997b983783dSGeoffrey Martin-Noble         fusedOperands.push_back(inputValue);
1998b983783dSGeoffrey Martin-Noble         fusedLocs.push_back(inputValue.getLoc());
19995994201cSMaheshRavishankar       }
20000b2197b0SMatthias Springer       for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
20011227b8abSOleg Shyshkov         fusedIndexMaps.push_back(
20020b2197b0SMatthias Springer             genericOp.getMatchingIndexingMap(&outputOperand));
20035994201cSMaheshRavishankar 
20045994201cSMaheshRavishankar       // Check that the shapes-to-loops map of the fused op is computable.
200506514c55SIan Wood       if (!inversePermutation(
200606514c55SIan Wood               concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
20075994201cSMaheshRavishankar         return rewriter.notifyMatchFailure(
20085994201cSMaheshRavishankar             genericOp, "fused op loop bound computation failed");
20095994201cSMaheshRavishankar       }
20105994201cSMaheshRavishankar 
20115994201cSMaheshRavishankar       // Create a scalar constant from the matched constant attribute.
201200e3566dSRahul Kayaith       Value scalarConstant =
201300e3566dSRahul Kayaith           rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
20145994201cSMaheshRavishankar 
2015a7cccb9cSAlexander Belyaev       SmallVector<Value> outputOperands = genericOp.getOutputs();
20165994201cSMaheshRavishankar       auto fusedOp = rewriter.create<GenericOp>(
2017b983783dSGeoffrey Martin-Noble           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
20185994201cSMaheshRavishankar           /*inputs=*/fusedOperands,
20195994201cSMaheshRavishankar           /*outputs=*/outputOperands,
20205994201cSMaheshRavishankar           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2021d3b3f765SJacques Pienaar           genericOp.getIteratorTypes(),
20225994201cSMaheshRavishankar           /*doc=*/nullptr,
20235994201cSMaheshRavishankar           /*library_call=*/nullptr);
20245994201cSMaheshRavishankar 
20255994201cSMaheshRavishankar       // Map the block argument corresponding to the replaced operand to the
20265994201cSMaheshRavishankar       // scalar constant.
20275994201cSMaheshRavishankar       Region &region = genericOp->getRegion(0);
20285994201cSMaheshRavishankar       Block &entryBlock = *region.begin();
20294d67b278SJeff Niu       IRMapping mapping;
20305994201cSMaheshRavishankar       mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
20315994201cSMaheshRavishankar                   scalarConstant);
20325994201cSMaheshRavishankar       Region &fusedRegion = fusedOp->getRegion(0);
20335994201cSMaheshRavishankar       rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
20345994201cSMaheshRavishankar                                  mapping);
20355994201cSMaheshRavishankar       rewriter.replaceOp(genericOp, fusedOp->getResults());
20365994201cSMaheshRavishankar       return success();
20375994201cSMaheshRavishankar     }
20385994201cSMaheshRavishankar     return failure();
20395994201cSMaheshRavishankar   }
20404cd7ff67SLei Zhang };
20414cd7ff67SLei Zhang 
20425994201cSMaheshRavishankar } // namespace
20435994201cSMaheshRavishankar 
204432288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
204532288d37SMahesh Ravishankar // Miscellaneous patterns that help fusion.
204632288d37SMahesh Ravishankar //===---------------------------------------------------------------------===//
20475994201cSMaheshRavishankar 
20485994201cSMaheshRavishankar namespace {
204981ca5aa4SMatthias Springer /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
205081ca5aa4SMatthias Springer /// value of the `outs` operand is not used within the op. This is only
20515994201cSMaheshRavishankar /// implemented for `linalg.generic` operations for now, but should hold for all
20525994201cSMaheshRavishankar /// linalg structured ops.
20535994201cSMaheshRavishankar struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
20545994201cSMaheshRavishankar   using OpRewritePattern<GenericOp>::OpRewritePattern;
20555994201cSMaheshRavishankar 
20565994201cSMaheshRavishankar   LogicalResult matchAndRewrite(GenericOp op,
20575994201cSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
20585fcf907bSMatthias Springer     rewriter.startOpModification(op);
20595994201cSMaheshRavishankar     bool modifiedOutput = false;
20605994201cSMaheshRavishankar     Location loc = op.getLoc();
20610b2197b0SMatthias Springer     for (OpOperand &opOperand : op.getDpsInitsMutable()) {
20620b2197b0SMatthias Springer       if (!op.payloadUsesValueFromOperand(&opOperand)) {
20630b2197b0SMatthias Springer         Value operandVal = opOperand.get();
20645550c821STres Popp         auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
20655994201cSMaheshRavishankar         if (!operandType)
20665994201cSMaheshRavishankar           continue;
20675994201cSMaheshRavishankar 
2068c43e6274STim Harvey         // If outs is sparse, leave it to the sparsifier.
2069515c6170SAart Bik         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2070515c6170SAart Bik           continue;
2071515c6170SAart Bik 
207281ca5aa4SMatthias Springer         // If outs is already an `empty` operation, nothing to do.
207381ca5aa4SMatthias Springer auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>(); 20745994201cSMaheshRavishankar if (definingOp) 20755994201cSMaheshRavishankar continue; 20765994201cSMaheshRavishankar modifiedOutput = true; 20776596b0ddSMatthias Springer SmallVector<OpFoldResult> mixedSizes = 20786596b0ddSMatthias Springer tensor::getMixedSizes(rewriter, loc, operandVal); 207981ca5aa4SMatthias Springer Value emptyTensor = rewriter.create<tensor::EmptyOp>( 20806596b0ddSMatthias Springer loc, mixedSizes, operandType.getElementType()); 20810b2197b0SMatthias Springer op->setOperand(opOperand.getOperandNumber(), emptyTensor); 20825994201cSMaheshRavishankar } 20835994201cSMaheshRavishankar } 20845994201cSMaheshRavishankar if (!modifiedOutput) { 20855fcf907bSMatthias Springer rewriter.cancelOpModification(op); 20865994201cSMaheshRavishankar return failure(); 20875994201cSMaheshRavishankar } 20885fcf907bSMatthias Springer rewriter.finalizeOpModification(op); 20895994201cSMaheshRavishankar return success(); 20905994201cSMaheshRavishankar } 20915994201cSMaheshRavishankar }; 20925994201cSMaheshRavishankar 209301055ed1SNirvedh /// Fold linalg.fill into linalg.generic 209401055ed1SNirvedh struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> { 209501055ed1SNirvedh using OpRewritePattern<GenericOp>::OpRewritePattern; 209601055ed1SNirvedh 209701055ed1SNirvedh LogicalResult matchAndRewrite(GenericOp genericOp, 209801055ed1SNirvedh PatternRewriter &rewriter) const override { 20990a8e3dd4SMatthias Springer if (!genericOp.hasPureTensorSemantics()) 210001055ed1SNirvedh return failure(); 210101055ed1SNirvedh bool fillFound = false; 2102d3b3f765SJacques Pienaar Block &payload = genericOp.getRegion().front(); 2103b4db15a9SAlexander Belyaev for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { 210401055ed1SNirvedh if (!genericOp.payloadUsesValueFromOperand(opOperand)) 210501055ed1SNirvedh continue; 210601055ed1SNirvedh FillOp fillOp = opOperand->get().getDefiningOp<FillOp>(); 210701055ed1SNirvedh if (!fillOp) 210801055ed1SNirvedh continue; 210901055ed1SNirvedh fillFound = true; 211004b449e1SPrashant Kumar Value fillVal = fillOp.value(); 211104b449e1SPrashant Kumar auto resultType = 21125550c821STres Popp cast<RankedTensorType>(fillOp.result().getType()).getElementType(); 211304b449e1SPrashant Kumar Value convertedVal = 211404b449e1SPrashant Kumar convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType, 211504b449e1SPrashant Kumar /*isUnsignedCast =*/false); 2116b13a197eSMatthias Springer rewriter.replaceAllUsesWith( 2117b13a197eSMatthias Springer payload.getArgument(opOperand->getOperandNumber()), convertedVal); 211801055ed1SNirvedh } 211901055ed1SNirvedh return success(fillFound); 212001055ed1SNirvedh } 212101055ed1SNirvedh }; 212201055ed1SNirvedh } // namespace 21235994201cSMaheshRavishankar 21245994201cSMaheshRavishankar void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 21255994201cSMaheshRavishankar RewritePatternSet &patterns, 21262291705dSMahesh Ravishankar const ControlFusionFn &controlFoldingReshapes) { 2127b546f434SMaheshRavishankar patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 2128b546f434SMaheshRavishankar controlFoldingReshapes); 2129c886d66dSMax191 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), 2130c886d66dSMax191 controlFoldingReshapes); 21315994201cSMaheshRavishankar patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 21325994201cSMaheshRavishankar 
controlFoldingReshapes); 21335994201cSMaheshRavishankar } 21345994201cSMaheshRavishankar 21352c58cde0SMahesh Ravishankar void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( 21362c58cde0SMahesh Ravishankar RewritePatternSet &patterns, 21372291705dSMahesh Ravishankar const ControlFusionFn &controlFoldingReshapes) { 21382c58cde0SMahesh Ravishankar patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(), 21392c58cde0SMahesh Ravishankar controlFoldingReshapes); 2140c886d66dSMax191 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( 2141c886d66dSMax191 patterns.getContext(), controlFoldingReshapes); 21422c58cde0SMahesh Ravishankar } 21432c58cde0SMahesh Ravishankar 21445994201cSMaheshRavishankar void mlir::linalg::populateElementwiseOpsFusionPatterns( 21452291705dSMahesh Ravishankar RewritePatternSet &patterns, 21462291705dSMahesh Ravishankar const ControlFusionFn &controlElementwiseOpsFusion) { 21475994201cSMaheshRavishankar auto *context = patterns.getContext(); 21482291705dSMahesh Ravishankar patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); 21492291705dSMahesh Ravishankar patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant, 21502291705dSMahesh Ravishankar RemoveOutsDependency>(context); 2151da8a8e92SMahesh Ravishankar // Add the patterns that clean up dead operands and results. 2152da8a8e92SMahesh Ravishankar populateEraseUnusedOperandsAndResultsPatterns(patterns); 21535994201cSMaheshRavishankar } 21545994201cSMaheshRavishankar 215583c65fbcSThomas Raoux void mlir::linalg::populateCollapseDimensions( 215683c65fbcSThomas Raoux RewritePatternSet &patterns, 215783c65fbcSThomas Raoux const GetCollapsableDimensionsFn &controlCollapseDimensions) { 21585c3ed392SAviad Cohen patterns.add<CollapseLinalgDimensions<linalg::GenericOp>, 21595c3ed392SAviad Cohen CollapseLinalgDimensions<linalg::CopyOp>>( 21605c3ed392SAviad Cohen patterns.getContext(), controlCollapseDimensions); 216183c65fbcSThomas Raoux } 216283c65fbcSThomas Raoux 216332288d37SMahesh Ravishankar //===---------------------------------------------------------------------===// 216432288d37SMahesh Ravishankar // Passes 216532288d37SMahesh Ravishankar //===---------------------------------------------------------------------===// 216632288d37SMahesh Ravishankar 216732288d37SMahesh Ravishankar namespace { 216832288d37SMahesh Ravishankar 216932288d37SMahesh Ravishankar /// Pass that fuses generic ops on tensors. Used only for testing. 21702291705dSMahesh Ravishankar // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the 21712291705dSMahesh Ravishankar // patterns added here heavily depends on the cost function used. Having an 21722291705dSMahesh Ravishankar // opinionated pass of this form is not recommended. Deprecate this pass in 21732291705dSMahesh Ravishankar // favor of test passes that check the functionality of each of the patterns 21742291705dSMahesh Ravishankar // added here individually. 
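// As a hedged sketch, a downstream pass would instead assemble only the
// pattern sets it needs (the populate functions are the ones defined above;
// the driver itself is hypothetical):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateElementwiseOpsFusionPatterns(patterns, controlFn);
//   linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
//   (void)applyPatternsGreedily(op, std::move(patterns));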
217532288d37SMahesh Ravishankar struct LinalgElementwiseOpFusionPass
21761e98d488SQuinn Dawkins     : public impl::LinalgElementwiseOpFusionPassBase<
217767d0d7acSMichele Scuttari           LinalgElementwiseOpFusionPass> {
21781e98d488SQuinn Dawkins   using impl::LinalgElementwiseOpFusionPassBase<
21791e98d488SQuinn Dawkins       LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
218032288d37SMahesh Ravishankar   void runOnOperation() override {
218132288d37SMahesh Ravishankar     Operation *op = getOperation();
21822291705dSMahesh Ravishankar     MLIRContext *context = op->getContext();
21832291705dSMahesh Ravishankar     RewritePatternSet patterns(context);
21842291705dSMahesh Ravishankar 
21852291705dSMahesh Ravishankar     // Default control function: fuse only when the producer has a single use.
2186a7bfdc23SMahesh Ravishankar     ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) {
2187a7bfdc23SMahesh Ravishankar       Operation *producer = fusedOperand->get().getDefiningOp();
2188a7bfdc23SMahesh Ravishankar       return producer && producer->hasOneUse();
218932288d37SMahesh Ravishankar     };
21902291705dSMahesh Ravishankar 
21912291705dSMahesh Ravishankar     // Add elementwise op fusion patterns.
21922291705dSMahesh Ravishankar     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
21930c090dccSMahesh Ravishankar     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
2194a95ad2daSIan Wood     tensor::populateBubbleUpExpandShapePatterns(patterns);
21952291705dSMahesh Ravishankar 
21962291705dSMahesh Ravishankar     // General canonicalization patterns.
21974c48f016SMatthias Springer     affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
21982291705dSMahesh Ravishankar     GenericOp::getCanonicalizationPatterns(patterns, context);
21992291705dSMahesh Ravishankar     tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
22002291705dSMahesh Ravishankar     tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
22012291705dSMahesh Ravishankar     context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
22022291705dSMahesh Ravishankar         patterns);
22032291705dSMahesh Ravishankar 
22042291705dSMahesh Ravishankar     // Add constant folding patterns.
22052291705dSMahesh Ravishankar     populateConstantFoldLinalgOperations(patterns, defaultControlFn);
220632288d37SMahesh Ravishankar 
220732288d37SMahesh Ravishankar     // Use top-down traversal for compile-time reasons.
220832288d37SMahesh Ravishankar     GreedyRewriteConfig grc;
220932288d37SMahesh Ravishankar     grc.useTopDownTraversal = true;
221009dfc571SJacques Pienaar     (void)applyPatternsGreedily(op, std::move(patterns), grc);
221132288d37SMahesh Ravishankar   }
221232288d37SMahesh Ravishankar };
221332288d37SMahesh Ravishankar 
221432288d37SMahesh Ravishankar } // namespace