12b0c8546SMaheshRavishankar //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===// 22b0c8546SMaheshRavishankar // 32b0c8546SMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42b0c8546SMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information. 52b0c8546SMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62b0c8546SMaheshRavishankar // 72b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===// 82b0c8546SMaheshRavishankar // 92b0c8546SMaheshRavishankar // This file implements patterns/pass to remove usage of unit-extent dimensions 102b0c8546SMaheshRavishankar // to specify broadcasting in favor of more canonical representation of the 112b0c8546SMaheshRavishankar // computation 122b0c8546SMaheshRavishankar // 132b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===// 142b0c8546SMaheshRavishankar 1567d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h" 1667d0d7acSMichele Scuttari 1767d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/IR/AffineOps.h" 18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 19b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 20ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 212b0c8546SMaheshRavishankar #include "mlir/Dialect/Linalg/Utils/Utils.h" 22faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 23060208b4SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 24f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 259b16d9d2SHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h" 2697069a86SGaurav Shukla #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" 272b0c8546SMaheshRavishankar #include "mlir/IR/AffineExpr.h" 282b0c8546SMaheshRavishankar #include "mlir/IR/AffineMap.h" 296c7be417STres 
Popp #include "mlir/IR/BuiltinTypes.h" 302b0c8546SMaheshRavishankar #include "mlir/Transforms/FoldUtils.h" 31b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 329b16d9d2SHanhan Wang #include "llvm/ADT/SetVector.h" 332b0c8546SMaheshRavishankar #include "llvm/Support/CommandLine.h" 342b0c8546SMaheshRavishankar #include "llvm/Support/Debug.h" 352b0c8546SMaheshRavishankar 3667d0d7acSMichele Scuttari namespace mlir { 371e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS 3867d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc" 3967d0d7acSMichele Scuttari } // namespace mlir 4067d0d7acSMichele Scuttari 412b0c8546SMaheshRavishankar #define DEBUG_TYPE "linalg-drop-unit-dims" 422b0c8546SMaheshRavishankar 432b0c8546SMaheshRavishankar using namespace mlir; 442b0c8546SMaheshRavishankar using namespace mlir::linalg; 452b0c8546SMaheshRavishankar 46e07149c9SMatthias Springer namespace { 47d2b070d3SMatthias Springer /// Pattern to move init operands to ins when all the loops are parallel and 489b16d9d2SHanhan Wang /// blockArgument corresponding to init is used in the region. This is a fix-up 499b16d9d2SHanhan Wang /// when unit reduction dimensions are all folded away. In this context, it 509b16d9d2SHanhan Wang /// becomes a elementwise generic op. 
E.g., it converts 519b16d9d2SHanhan Wang /// 529b16d9d2SHanhan Wang /// %0 = tensor.empty() : tensor<1x1xf32> 539b16d9d2SHanhan Wang /// %1 = linalg.fill 549b16d9d2SHanhan Wang /// ins(%cst : f32) 559b16d9d2SHanhan Wang /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> 569b16d9d2SHanhan Wang /// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, 579b16d9d2SHanhan Wang /// affine_map<(d0) -> (0, d0)>], 589b16d9d2SHanhan Wang /// iterator_types = ["parallel"]} 599b16d9d2SHanhan Wang /// ins(%arg0 : tensor<1x?x1x1xf32>) 609b16d9d2SHanhan Wang /// outs(%1 : tensor<1x1xf32>) { 619b16d9d2SHanhan Wang /// ^bb0(%in: f32, %out: f32): 629b16d9d2SHanhan Wang /// %3 = arith.addf %in, %out : f32 639b16d9d2SHanhan Wang /// linalg.yield %3 : f32 649b16d9d2SHanhan Wang /// } -> tensor<1x1xf32> 659b16d9d2SHanhan Wang /// 669b16d9d2SHanhan Wang /// into 679b16d9d2SHanhan Wang /// 689b16d9d2SHanhan Wang /// %0 = tensor.empty() : tensor<1x1xf32> 699b16d9d2SHanhan Wang /// %1 = linalg.fill 709b16d9d2SHanhan Wang /// ins(%cst : f32) 719b16d9d2SHanhan Wang /// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> 729b16d9d2SHanhan Wang /// %2 = tensor.empty() : tensor<1x1xf32> 739b16d9d2SHanhan Wang /// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, 749b16d9d2SHanhan Wang /// affine_map<(d0) -> (0, d0)>, 759b16d9d2SHanhan Wang /// affine_map<(d0) -> (0, d0)>], 769b16d9d2SHanhan Wang /// iterator_types = ["parallel"]} 779b16d9d2SHanhan Wang /// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>) 789b16d9d2SHanhan Wang /// outs(%2 : tensor<1x1xf32>) { 799b16d9d2SHanhan Wang /// ^bb0(%in: f32, %in_0: f32, %out: f32): 809b16d9d2SHanhan Wang /// %4 = arith.addf %in, %in_0 : f32 819b16d9d2SHanhan Wang /// linalg.yield %4 : f32 829b16d9d2SHanhan Wang /// } -> tensor<1x1xf32> 83d2b070d3SMatthias Springer struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> { 849b16d9d2SHanhan Wang using OpRewritePattern<GenericOp>::OpRewritePattern; 
859b16d9d2SHanhan Wang LogicalResult matchAndRewrite(GenericOp genericOp, 869b16d9d2SHanhan Wang PatternRewriter &rewriter) const override { 870a8e3dd4SMatthias Springer if (!genericOp.hasPureTensorSemantics()) 889b16d9d2SHanhan Wang return failure(); 899b16d9d2SHanhan Wang if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) 909b16d9d2SHanhan Wang return failure(); 919b16d9d2SHanhan Wang 920b2197b0SMatthias Springer auto outputOperands = genericOp.getDpsInitsMutable(); 939b16d9d2SHanhan Wang SetVector<OpOperand *> candidates; 940b2197b0SMatthias Springer for (OpOperand &op : outputOperands) { 950b2197b0SMatthias Springer if (genericOp.getMatchingBlockArgument(&op).use_empty()) 969b16d9d2SHanhan Wang continue; 970b2197b0SMatthias Springer candidates.insert(&op); 989b16d9d2SHanhan Wang } 999b16d9d2SHanhan Wang 1009b16d9d2SHanhan Wang if (candidates.empty()) 1019b16d9d2SHanhan Wang return failure(); 1029b16d9d2SHanhan Wang 1039b16d9d2SHanhan Wang // Compute the modified indexing maps. 
1049b16d9d2SHanhan Wang int64_t origNumInput = genericOp.getNumDpsInputs(); 1050b2197b0SMatthias Springer SmallVector<Value> newInputOperands = genericOp.getDpsInputs(); 1069b16d9d2SHanhan Wang SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); 1079b16d9d2SHanhan Wang SmallVector<AffineMap> newIndexingMaps; 1089b16d9d2SHanhan Wang newIndexingMaps.append(indexingMaps.begin(), 1099b16d9d2SHanhan Wang std::next(indexingMaps.begin(), origNumInput)); 1109b16d9d2SHanhan Wang for (OpOperand *op : candidates) { 1119b16d9d2SHanhan Wang newInputOperands.push_back(op->get()); 1129b16d9d2SHanhan Wang newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op)); 1139b16d9d2SHanhan Wang } 1149b16d9d2SHanhan Wang newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput), 1159b16d9d2SHanhan Wang indexingMaps.end()); 1169b16d9d2SHanhan Wang 1179b16d9d2SHanhan Wang Location loc = genericOp.getLoc(); 1180b2197b0SMatthias Springer SmallVector<Value> newOutputOperands = 1190b2197b0SMatthias Springer llvm::to_vector(genericOp.getDpsInits()); 1209b16d9d2SHanhan Wang for (OpOperand *op : candidates) { 1219b16d9d2SHanhan Wang OpBuilder::InsertionGuard guard(rewriter); 1229b16d9d2SHanhan Wang rewriter.setInsertionPointAfterValue(op->get()); 1235550c821STres Popp auto elemType = cast<ShapedType>(op->get().getType()).getElementType(); 1249b16d9d2SHanhan Wang auto empty = rewriter.create<tensor::EmptyOp>( 1256596b0ddSMatthias Springer loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType); 1269b16d9d2SHanhan Wang 1270b2197b0SMatthias Springer unsigned start = genericOp.getDpsInits().getBeginOperandIndex(); 1289b16d9d2SHanhan Wang newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); 1299b16d9d2SHanhan Wang } 1309b16d9d2SHanhan Wang 1319b16d9d2SHanhan Wang auto newOp = rewriter.create<GenericOp>( 1329b16d9d2SHanhan Wang loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands, 1339b16d9d2SHanhan Wang newIndexingMaps, 
genericOp.getIteratorTypesArray(), 1349b16d9d2SHanhan Wang /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); 1359b16d9d2SHanhan Wang 1369b16d9d2SHanhan Wang OpBuilder::InsertionGuard guard(rewriter); 13791d5653eSMatthias Springer Region ®ion = newOp.getRegion(); 13891d5653eSMatthias Springer Block *block = rewriter.createBlock(®ion); 13991d5653eSMatthias Springer IRMapping mapper; 1409b16d9d2SHanhan Wang for (auto bbarg : genericOp.getRegionInputArgs()) 1419b16d9d2SHanhan Wang mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); 1429b16d9d2SHanhan Wang 1439b16d9d2SHanhan Wang for (OpOperand *op : candidates) { 1449b16d9d2SHanhan Wang BlockArgument bbarg = genericOp.getMatchingBlockArgument(op); 1459b16d9d2SHanhan Wang mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); 1469b16d9d2SHanhan Wang } 1479b16d9d2SHanhan Wang 1480b2197b0SMatthias Springer for (OpOperand &op : outputOperands) { 1490b2197b0SMatthias Springer BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op); 1500b2197b0SMatthias Springer if (candidates.count(&op)) 1519b16d9d2SHanhan Wang block->addArgument(bbarg.getType(), loc); 1529b16d9d2SHanhan Wang else 1539b16d9d2SHanhan Wang mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); 1549b16d9d2SHanhan Wang } 1559b16d9d2SHanhan Wang 1569b16d9d2SHanhan Wang for (auto &op : genericOp.getBody()->getOperations()) { 1579b16d9d2SHanhan Wang rewriter.clone(op, mapper); 1589b16d9d2SHanhan Wang } 1599b16d9d2SHanhan Wang rewriter.replaceOp(genericOp, newOp.getResults()); 1609b16d9d2SHanhan Wang 1619b16d9d2SHanhan Wang return success(); 1629b16d9d2SHanhan Wang } 1639b16d9d2SHanhan Wang }; 1642b0c8546SMaheshRavishankar } // namespace 1652b0c8546SMaheshRavishankar 16667399932SMahesh Ravishankar //===---------------------------------------------------------------------===// 16767399932SMahesh Ravishankar // Drop loops that are unit-extents within Linalg operations. 
16867399932SMahesh Ravishankar //===---------------------------------------------------------------------===// 1692b0c8546SMaheshRavishankar 17067399932SMahesh Ravishankar /// Implements a pass that canonicalizes the uses of unit-extent dimensions for 17167399932SMahesh Ravishankar /// broadcasting. For example, 17267399932SMahesh Ravishankar /// 17367399932SMahesh Ravishankar /// ```mlir 17467399932SMahesh Ravishankar /// #accesses = [ 17567399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (0, d1)>, 17667399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (d0, 0)>, 17767399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (d0, d1)> 17867399932SMahesh Ravishankar /// ] 17967399932SMahesh Ravishankar /// 18067399932SMahesh Ravishankar /// #trait = { 18167399932SMahesh Ravishankar /// indexing_maps = #accesses, 18267399932SMahesh Ravishankar /// iterator_types = ["parallel", "parallel"], 18367399932SMahesh Ravishankar /// library_call = "some_external_fn" 18467399932SMahesh Ravishankar /// } 18567399932SMahesh Ravishankar /// 18667399932SMahesh Ravishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> 18767399932SMahesh Ravishankar /// tensor<5x5xf32> 18867399932SMahesh Ravishankar /// { 18967399932SMahesh Ravishankar /// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] : 19067399932SMahesh Ravishankar /// tensor<5xf32> into tensor<1x5xf32> 19167399932SMahesh Ravishankar /// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] : 19267399932SMahesh Ravishankar /// tensor<5xf32> into tensor<5x1xf32> 19367399932SMahesh Ravishankar /// %2 = linalg.generic #trait %0, %1 { 19467399932SMahesh Ravishankar /// ^bb0(%arg2: f32, %arg3: f32): 19567399932SMahesh Ravishankar /// %3 = arith.addf %arg2, %arg3 : f32 19667399932SMahesh Ravishankar /// linalg.yield %3 : f32 19767399932SMahesh Ravishankar /// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32> 19867399932SMahesh Ravishankar /// return %2 : 
tensor<5x5xf32> 19967399932SMahesh Ravishankar /// } 20067399932SMahesh Ravishankar /// 20167399932SMahesh Ravishankar /// would canonicalize to 20267399932SMahesh Ravishankar /// 20367399932SMahesh Ravishankar /// ```mlir 20467399932SMahesh Ravishankar /// #accesses = [ 20567399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (d1)>, 20667399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (d0)>, 20767399932SMahesh Ravishankar /// affine_map<(d0, d1) -> (d0, d1)> 20867399932SMahesh Ravishankar /// ] 20967399932SMahesh Ravishankar /// 21067399932SMahesh Ravishankar /// #trait = { 21167399932SMahesh Ravishankar /// indexing_maps = #accesses, 21267399932SMahesh Ravishankar /// iterator_types = ["parallel", "parallel"], 21367399932SMahesh Ravishankar /// library_call = "some_external_fn" 21467399932SMahesh Ravishankar /// } 21567399932SMahesh Ravishankar /// 21667399932SMahesh Ravishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> 21767399932SMahesh Ravishankar /// tensor<5x5xf32> 21867399932SMahesh Ravishankar /// { 21967399932SMahesh Ravishankar /// %0 = linalg.generic #trait %arg0, %arg1 { 22067399932SMahesh Ravishankar /// ^bb0(%arg2: f32, %arg3: f32): 22167399932SMahesh Ravishankar /// %3 = arith.addf %arg2, %arg3 : f32 22267399932SMahesh Ravishankar /// linalg.yield %3 : f32 22367399932SMahesh Ravishankar /// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32> 22467399932SMahesh Ravishankar /// return %0 : tensor<5x5xf32> 22567399932SMahesh Ravishankar /// } 2262b0c8546SMaheshRavishankar 22767399932SMahesh Ravishankar /// Update the index accesses of linalg operations having index semantics. 
22867399932SMahesh Ravishankar static void 22967399932SMahesh Ravishankar replaceUnitDimIndexOps(GenericOp genericOp, 23067399932SMahesh Ravishankar const llvm::SmallDenseSet<unsigned> &unitDims, 23167399932SMahesh Ravishankar RewriterBase &rewriter) { 23267399932SMahesh Ravishankar for (IndexOp indexOp : 23367399932SMahesh Ravishankar llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) { 23467399932SMahesh Ravishankar OpBuilder::InsertionGuard guard(rewriter); 23567399932SMahesh Ravishankar rewriter.setInsertionPoint(indexOp); 23667399932SMahesh Ravishankar if (unitDims.count(indexOp.getDim()) != 0) { 23767399932SMahesh Ravishankar rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0); 23867399932SMahesh Ravishankar } else { 23967399932SMahesh Ravishankar // Update the dimension of the index operation if needed. 24067399932SMahesh Ravishankar unsigned droppedDims = llvm::count_if( 24167399932SMahesh Ravishankar unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); }); 24267399932SMahesh Ravishankar if (droppedDims != 0) 24367399932SMahesh Ravishankar rewriter.replaceOpWithNewOp<IndexOp>(indexOp, 24467399932SMahesh Ravishankar indexOp.getDim() - droppedDims); 24567399932SMahesh Ravishankar } 24667399932SMahesh Ravishankar } 24744485fcdSTres Popp } 24844485fcdSTres Popp 24967399932SMahesh Ravishankar /// Expand the given `value` so that the type matches the type of `origDest`. 25067399932SMahesh Ravishankar /// The `reassociation` is used when `rankReductionStrategy` is set to 25167399932SMahesh Ravishankar /// `RankReductionStrategy::ReassociativeReshape`. 
25267399932SMahesh Ravishankar static Value 25367399932SMahesh Ravishankar expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, 25467399932SMahesh Ravishankar ArrayRef<ReassociationIndices> reassociation, 25567399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { 256e07149c9SMatthias Springer // There are no results for memref outputs. 25767399932SMahesh Ravishankar auto origResultType = cast<RankedTensorType>(origDest.getType()); 25867399932SMahesh Ravishankar if (rankReductionStrategy == 25967399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { 260e07149c9SMatthias Springer unsigned rank = origResultType.getRank(); 261e07149c9SMatthias Springer SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); 262e07149c9SMatthias Springer SmallVector<OpFoldResult> sizes = 26367399932SMahesh Ravishankar tensor::getMixedSizes(rewriter, loc, origDest); 264e07149c9SMatthias Springer SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); 265e07149c9SMatthias Springer return rewriter.createOrFold<tensor::InsertSliceOp>( 26667399932SMahesh Ravishankar loc, result, origDest, offsets, sizes, strides); 2676c7be417STres Popp } 2686c7be417STres Popp 269e07149c9SMatthias Springer assert(rankReductionStrategy == 27067399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && 271e07149c9SMatthias Springer "unknown rank reduction strategy"); 27297069a86SGaurav Shukla return rewriter 27397069a86SGaurav Shukla .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation) 27497069a86SGaurav Shukla .getResult(); 2756c7be417STres Popp } 276e07149c9SMatthias Springer 27767399932SMahesh Ravishankar /// Collapse the given `value` so that the type matches the type of 27867399932SMahesh Ravishankar /// `origOutput`. 
The `reassociation` is used when `rankReductionStrategy` is 27967399932SMahesh Ravishankar /// set to `RankReductionStrategy::ReassociativeReshape`. 28067399932SMahesh Ravishankar static Value collapseValue( 28167399932SMahesh Ravishankar RewriterBase &rewriter, Location loc, Value operand, 28267399932SMahesh Ravishankar ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation, 28367399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) { 2845550c821STres Popp if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { 28567399932SMahesh Ravishankar if (rankReductionStrategy == 28667399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { 287e07149c9SMatthias Springer FailureOr<Value> rankReducingExtract = 288e07149c9SMatthias Springer memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, 289e07149c9SMatthias Springer targetShape); 290e07149c9SMatthias Springer assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); 291e07149c9SMatthias Springer return *rankReducingExtract; 2926c7be417STres Popp } 293e07149c9SMatthias Springer 29467399932SMahesh Ravishankar assert( 29567399932SMahesh Ravishankar rankReductionStrategy == 29667399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && 297e07149c9SMatthias Springer "unknown rank reduction strategy"); 298e07149c9SMatthias Springer MemRefLayoutAttrInterface layout; 29967399932SMahesh Ravishankar auto targetType = MemRefType::get(targetShape, memrefType.getElementType(), 30067399932SMahesh Ravishankar layout, memrefType.getMemorySpace()); 301e07149c9SMatthias Springer return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand, 302e07149c9SMatthias Springer reassociation); 303e07149c9SMatthias Springer } 3045550c821STres Popp if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) { 30567399932SMahesh Ravishankar if (rankReductionStrategy == 
30667399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) { 307e07149c9SMatthias Springer FailureOr<Value> rankReducingExtract = 308e07149c9SMatthias Springer tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, 309e07149c9SMatthias Springer targetShape); 310e07149c9SMatthias Springer assert(succeeded(rankReducingExtract) && "not a unit-extent collapse"); 311e07149c9SMatthias Springer return *rankReducingExtract; 312e07149c9SMatthias Springer } 313e07149c9SMatthias Springer 31467399932SMahesh Ravishankar assert( 31567399932SMahesh Ravishankar rankReductionStrategy == 31667399932SMahesh Ravishankar ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && 317e07149c9SMatthias Springer "unknown rank reduction strategy"); 318e07149c9SMatthias Springer auto targetType = 319e07149c9SMatthias Springer RankedTensorType::get(targetShape, tensorType.getElementType()); 320e07149c9SMatthias Springer return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand, 321e07149c9SMatthias Springer reassociation); 322e07149c9SMatthias Springer } 323e07149c9SMatthias Springer llvm_unreachable("unsupported operand type"); 324e07149c9SMatthias Springer } 3256c7be417STres Popp 32667399932SMahesh Ravishankar /// Compute the modified metadata for an operands of operation 32767399932SMahesh Ravishankar /// whose unit dims are being dropped. Return the new indexing map 32867399932SMahesh Ravishankar /// to use, the shape of the operand in the replacement op 32967399932SMahesh Ravishankar /// and the `reassocation` to use to go from original operand shape 33067399932SMahesh Ravishankar /// to modified operand shape. 
33167399932SMahesh Ravishankar struct UnitExtentReplacementInfo { 33267399932SMahesh Ravishankar AffineMap indexMap; 33367399932SMahesh Ravishankar SmallVector<ReassociationIndices> reassociation; 33467399932SMahesh Ravishankar SmallVector<int64_t> targetShape; 33567399932SMahesh Ravishankar }; 33667399932SMahesh Ravishankar static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( 33767399932SMahesh Ravishankar MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, 33867399932SMahesh Ravishankar llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap, 33967399932SMahesh Ravishankar ArrayRef<AffineExpr> dimReplacements) { 34067399932SMahesh Ravishankar UnitExtentReplacementInfo info; 34167399932SMahesh Ravishankar ReassociationIndices reassociationGroup; 34267399932SMahesh Ravishankar SmallVector<AffineExpr> newIndexExprs; 34367399932SMahesh Ravishankar AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); 34467399932SMahesh Ravishankar ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand); 34567399932SMahesh Ravishankar ArrayRef<AffineExpr> exprs = indexingMap.getResults(); 3462b0c8546SMaheshRavishankar 34767399932SMahesh Ravishankar auto isUnitDim = [&](unsigned dim) { 3481609f1c2Slong.chen if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) { 34967399932SMahesh Ravishankar unsigned oldPosition = dimExpr.getPosition(); 35044861c7aSSayan Saha return !oldDimsToNewDimsMap.count(oldPosition) && 35144861c7aSSayan Saha (operandShape[dim] == 1); 35267399932SMahesh Ravishankar } 35367399932SMahesh Ravishankar // Handle the other case where the shape is 1, and is accessed using a 35467399932SMahesh Ravishankar // constant 0. 
35567399932SMahesh Ravishankar if (operandShape[dim] == 1) { 3561609f1c2Slong.chen auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]); 35767399932SMahesh Ravishankar return constAffineExpr && constAffineExpr.getValue() == 0; 35867399932SMahesh Ravishankar } 35967399932SMahesh Ravishankar return false; 36067399932SMahesh Ravishankar }; 36167399932SMahesh Ravishankar 3622efe88b5SMahesh Ravishankar unsigned dim = 0; 36367399932SMahesh Ravishankar while (dim < operandShape.size() && isUnitDim(dim)) 36467399932SMahesh Ravishankar reassociationGroup.push_back(dim++); 36567399932SMahesh Ravishankar while (dim < operandShape.size()) { 36667399932SMahesh Ravishankar assert(!isUnitDim(dim) && "expected non unit-extent"); 36767399932SMahesh Ravishankar reassociationGroup.push_back(dim); 36867399932SMahesh Ravishankar AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements); 36967399932SMahesh Ravishankar newIndexExprs.push_back(newExpr); 37067399932SMahesh Ravishankar info.targetShape.push_back(operandShape[dim]); 37167399932SMahesh Ravishankar ++dim; 37267399932SMahesh Ravishankar // Fold all following dimensions that are unit-extent. 
37367399932SMahesh Ravishankar while (dim < operandShape.size() && isUnitDim(dim)) { 37467399932SMahesh Ravishankar reassociationGroup.push_back(dim++); 37567399932SMahesh Ravishankar } 37667399932SMahesh Ravishankar info.reassociation.push_back(reassociationGroup); 37767399932SMahesh Ravishankar reassociationGroup.clear(); 37867399932SMahesh Ravishankar } 37967399932SMahesh Ravishankar info.indexMap = 38067399932SMahesh Ravishankar AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(), 38167399932SMahesh Ravishankar newIndexExprs, context); 38267399932SMahesh Ravishankar return info; 38367399932SMahesh Ravishankar } 38467399932SMahesh Ravishankar 3854dbaef6dSMaheshRavishankar FailureOr<DropUnitDimsResult> 3864dbaef6dSMaheshRavishankar linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, 38767399932SMahesh Ravishankar const ControlDropUnitDims &options) { 38867399932SMahesh Ravishankar SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); 38967399932SMahesh Ravishankar if (indexingMaps.empty()) 39067399932SMahesh Ravishankar return failure(); 39167399932SMahesh Ravishankar 39267399932SMahesh Ravishankar // 1. Check if any of the iteration dimensions are unit-trip count. They will 39367399932SMahesh Ravishankar // end up being unit-trip count if they are used to index into a unit-dim 39467399932SMahesh Ravishankar // tensor/memref. 39506514c55SIan Wood AffineMap invertedMap = 39606514c55SIan Wood inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext())); 39767399932SMahesh Ravishankar if (!invertedMap) { 39867399932SMahesh Ravishankar return rewriter.notifyMatchFailure(genericOp, 39967399932SMahesh Ravishankar "invalid indexing maps for operation"); 40067399932SMahesh Ravishankar } 40167399932SMahesh Ravishankar SmallVector<int64_t> dims = genericOp.getStaticShape(); 40267399932SMahesh Ravishankar 40367399932SMahesh Ravishankar // 1a. Get the allowed list of dimensions to drop from the `options`. 
40467399932SMahesh Ravishankar SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp); 40567399932SMahesh Ravishankar if (allowedUnitDims.empty()) { 40667399932SMahesh Ravishankar return rewriter.notifyMatchFailure( 40767399932SMahesh Ravishankar genericOp, "control function returns no allowed unit dims to prune"); 40867399932SMahesh Ravishankar } 40967399932SMahesh Ravishankar llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(), 41067399932SMahesh Ravishankar allowedUnitDims.end()); 41167399932SMahesh Ravishankar llvm::SmallDenseSet<unsigned> unitDims; 41267399932SMahesh Ravishankar for (const auto &expr : enumerate(invertedMap.getResults())) { 4131609f1c2Slong.chen if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) { 41467399932SMahesh Ravishankar if (dims[dimExpr.getPosition()] == 1 && 41567399932SMahesh Ravishankar unitDimsFilter.count(expr.index())) 41667399932SMahesh Ravishankar unitDims.insert(expr.index()); 41767399932SMahesh Ravishankar } 41867399932SMahesh Ravishankar } 41967399932SMahesh Ravishankar 42067399932SMahesh Ravishankar // 2. Compute the iterator types of the modified op by dropping the one-trip 42167399932SMahesh Ravishankar // count loops. 
42267399932SMahesh Ravishankar SmallVector<utils::IteratorType> newIteratorTypes; 42367399932SMahesh Ravishankar llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap; 42467399932SMahesh Ravishankar SmallVector<AffineExpr> dimReplacements; 42567399932SMahesh Ravishankar unsigned newDims = 0; 42667399932SMahesh Ravishankar for (auto [index, attr] : 42767399932SMahesh Ravishankar llvm::enumerate(genericOp.getIteratorTypesArray())) { 42867399932SMahesh Ravishankar if (unitDims.count(index)) { 42967399932SMahesh Ravishankar dimReplacements.push_back( 43067399932SMahesh Ravishankar getAffineConstantExpr(0, rewriter.getContext())); 43167399932SMahesh Ravishankar } else { 43267399932SMahesh Ravishankar newIteratorTypes.push_back(attr); 43367399932SMahesh Ravishankar oldDimToNewDimMap[index] = newDims; 43467399932SMahesh Ravishankar dimReplacements.push_back( 43567399932SMahesh Ravishankar getAffineDimExpr(newDims, rewriter.getContext())); 43667399932SMahesh Ravishankar newDims++; 43767399932SMahesh Ravishankar } 43867399932SMahesh Ravishankar } 43967399932SMahesh Ravishankar 44067399932SMahesh Ravishankar // 3. For each of the operands, find the 44167399932SMahesh Ravishankar // - modified affine map to use. 44267399932SMahesh Ravishankar // - shape of the operands after the unit-dims are dropped. 44367399932SMahesh Ravishankar // - the reassociation indices used to convert from the original 44467399932SMahesh Ravishankar // operand type to modified operand (needed only when using reshapes 44567399932SMahesh Ravishankar // for rank reduction strategy) 44667399932SMahesh Ravishankar // Note that the indexing maps might need changing even if there are no 44767399932SMahesh Ravishankar // unit dimensions that are dropped to handle cases where `0` is used to 44867399932SMahesh Ravishankar // access a unit-extent tensor. Consider moving this out of this specific 44967399932SMahesh Ravishankar // transformation as a stand-alone transformation. 
Kept here right now due
// to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  // An operand can only be collapsed through reshapes if it is a memref with
  // an identity layout or a ranked tensor without an encoding; anything else
  // is left untouched.
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  // Compute, per operand, the replacement metadata: new indexing map, target
  // (collapsed) shape and the reassociation used to get there. Operands with
  // non-collapsible types only get their indexing map remapped to the
  // reduced iteration space and keep their original shape.
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
    ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    // The operand counts as collapsed iff dropping unit dims removed at least
    // one result from its indexing map.
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = genericOp.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  // from original shape to shape in the modified operation if needed,
  // either through use of reshapes or rank-reducing slices as
  // specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : genericOp->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  // 5. Create the `linalg.generic` operation with the new operands,
  // indexing maps, iterator types and result types.
  ArrayRef<Value> newInputs =
      ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
  ArrayRef<Value> newOutputs =
      ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
  SmallVector<Type> resultTypes;
  resultTypes.reserve(genericOp.getNumResults());
  for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
    resultTypes.push_back(newOutputs[i].getType());
  GenericOp replacementOp =
      rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
                                 newIndexingMaps, newIteratorTypes);
  // The body is reused as-is; only index computations need fixing up below.
  rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                              replacementOp.getRegion().begin());
  // 5a. Replace `linalg.index` operations that refer to the dropped unit
  // dimensions.
  replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  // original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
    unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
    Value origDest = genericOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    Value expandedValue = expandValue(rewriter, loc, result, origDest,
                                      reassociations[opOperandIndex],
                                      options.rankReductionStrategy);
    resultReplacements.push_back(expandedValue);
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

namespace {
/// Rewrite pattern that folds unit-extent dimensions of a `linalg.generic`
/// by delegating to `dropUnitDims` and, on success, replacing the original
/// op with the computed replacement values.
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  /// Controls which dims may be dropped and which rank-reduction strategy
  /// (reshape vs. extract/insert_slice) is used.
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
/// Fold unit-extent, zero-padded dimensions out of `tensor.pad` ops allowed
/// by `options.controlFn`.
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      std::optional<int64_t> maybeInt = getConstantIntValue(f);
      return maybeInt && *maybeInt == 0;
    };

    // A dim is droppable only if it is in the allowed set, has static size 1,
    // and carries no low/high padding (both statically zero).
    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, low, high] :
         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    // Build the reassociation: every retained dim absorbs the unit dims that
    // immediately follow it; leading unit dims are folded into the first
    // retained dim.
    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    Value collapsedSource =
        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                      reassociationMap, options.rankReductionStrategy);

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
        newHighPad, paddingVal, padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      // For the slice-based strategy the full-rank destination has to be
      // materialized explicitly: unit dims get static size 1, the remaining
      // sizes are taken from the collapsed pad result.
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = rewriter.create<tensor::EmptyOp>(
          padOp.getLoc(), expandedSizes,
          padOp.getResultType().getElementType());
    }

    Value expandedValue =
        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                    reassociationMap, options.rankReductionStrategy);
    rewriter.replaceOp(padOp, expandedValue);
    return success();
  }

private:
  /// Controls which dims may be dropped and the rank-reduction strategy.
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    // Bail out if no unit dims can be folded (reassociation keeps the rank).
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    // Emit the rank-reduced slice, then expand back to the original result
    // type so existing users remain valid.
    Location loc = sliceOp.getLoc();
    Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    // Bail out if no unit dims can be folded (reassociation keeps the rank).
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is the insertion point is just before the ParallelCombiningOp in the
      // parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    // Re-create the insert with the collapsed source; offsets/sizes/strides
    // keep their original (full-rank) form.
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting, using `tensor.collapse_shape`/`tensor.expand_shape` style
/// rank reduction.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

/// Same as above, but for the rank-reducing `extract_slice`/`insert_slice`
/// strategy (no reshape-specific patterns registered).
static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

/// Public entry point: dispatch to the reshape- or slice-based pattern set
/// depending on `options.rankReductionStrategy`.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
8182b0c8546SMaheshRavishankar struct LinalgFoldUnitExtentDimsPass 8191e98d488SQuinn Dawkins : public impl::LinalgFoldUnitExtentDimsPassBase< 8201e98d488SQuinn Dawkins LinalgFoldUnitExtentDimsPass> { 8211e98d488SQuinn Dawkins using impl::LinalgFoldUnitExtentDimsPassBase< 8221e98d488SQuinn Dawkins LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase; 823c10995a8SStella Laurenzo void runOnOperation() override { 824c10995a8SStella Laurenzo Operation *op = getOperation(); 825c10995a8SStella Laurenzo MLIRContext *context = op->getContext(); 826dc4e913bSChris Lattner RewritePatternSet patterns(context); 82767399932SMahesh Ravishankar ControlDropUnitDims options; 82867399932SMahesh Ravishankar if (useRankReducingSlices) { 82967399932SMahesh Ravishankar options.rankReductionStrategy = linalg::ControlDropUnitDims:: 83067399932SMahesh Ravishankar RankReductionStrategy::ExtractInsertSlice; 831e07149c9SMatthias Springer } 83267399932SMahesh Ravishankar linalg::populateFoldUnitExtentDimsPatterns(patterns, options); 83367399932SMahesh Ravishankar populateMoveInitOperandsToInputPattern(patterns); 834*09dfc571SJacques Pienaar (void)applyPatternsGreedily(op, std::move(patterns)); 8352b0c8546SMaheshRavishankar } 8362b0c8546SMaheshRavishankar }; 837431213c9Ssrcarroll 8382b0c8546SMaheshRavishankar } // namespace 839431213c9Ssrcarroll 840431213c9Ssrcarroll namespace { 841431213c9Ssrcarroll 842431213c9Ssrcarroll /// Returns reassociation indices for collapsing/expanding a 843431213c9Ssrcarroll /// tensor of rank `rank` at position `pos`. 
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  // Produces `rank - 1` groups where the group containing `pos` (or, for the
  // last dim, `pos - 1`) spans two source dims; e.g. rank 4, pos 1 yields
  // [{0}, {1, 2}, {3}]. For rank <= 2 the default [{0, 1}] grouping is kept.
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands. An entry of `-1` in
  /// `operandCollapseDims` leaves the corresponding operand untouched
  /// (see `collapseSingletonDimAt`).
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor back to `expandedType` by reinserting the unit
  /// dim at `dim`.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    // Ask the derived pattern which unit dim (per operand) to collapse.
    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    // Only tensor inits produce a result; memref form has none.
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    // Carry over attributes, except the memoized indexing maps which are
    // specific to the original op.
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      // Expand the collapsed result back to the original result type using
      // the init's collapsed dim (operandUnitDims[2]).
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    // The batch dim must map to a dimension of all three operands, and that
    // dimension must be statically 1 everywhere.
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }
    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  /// True for the (From, To) op pairs where the unit `m` dim of the lhs and
  /// init is dropped; otherwise the unit `n` dim of the rhs and init is.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      // The `m` dim must be unit on both operands it appears in (lhs and
      // init); the rhs is left unchanged (-1 sentinel).
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      // Mirror case: drop the unit `n` dim of rhs and init; lhs unchanged.
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
          context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
          context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
  // Batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context); 1090431213c9Ssrcarroll patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>( 1091431213c9Ssrcarroll context); 1092431213c9Ssrcarroll patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>( 1093431213c9Ssrcarroll context); 1094431213c9Ssrcarroll 1095431213c9Ssrcarroll // Non-batch rank 0 reducing patterns 1096431213c9Ssrcarroll patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context); 1097431213c9Ssrcarroll patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context); 1098431213c9Ssrcarroll } 1099