xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
12b0c8546SMaheshRavishankar //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
22b0c8546SMaheshRavishankar //
32b0c8546SMaheshRavishankar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42b0c8546SMaheshRavishankar // See https://llvm.org/LICENSE.txt for license information.
52b0c8546SMaheshRavishankar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62b0c8546SMaheshRavishankar //
72b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===//
82b0c8546SMaheshRavishankar //
92b0c8546SMaheshRavishankar // This file implements patterns/pass to remove usage of unit-extent dimensions
// to specify broadcasting in favor of a more canonical representation of the
// computation.
122b0c8546SMaheshRavishankar //
132b0c8546SMaheshRavishankar //===----------------------------------------------------------------------===//
142b0c8546SMaheshRavishankar 
1567d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h"
1667d0d7acSMichele Scuttari 
1767d0d7acSMichele Scuttari #include "mlir/Dialect/Affine/IR/AffineOps.h"
18abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
19b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
20ea069aebSMaheshRavishankar #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
212b0c8546SMaheshRavishankar #include "mlir/Dialect/Linalg/Utils/Utils.h"
22faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
23060208b4SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
24f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
259b16d9d2SHanhan Wang #include "mlir/Dialect/Tensor/Utils/Utils.h"
2697069a86SGaurav Shukla #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
272b0c8546SMaheshRavishankar #include "mlir/IR/AffineExpr.h"
282b0c8546SMaheshRavishankar #include "mlir/IR/AffineMap.h"
296c7be417STres Popp #include "mlir/IR/BuiltinTypes.h"
302b0c8546SMaheshRavishankar #include "mlir/Transforms/FoldUtils.h"
31b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
329b16d9d2SHanhan Wang #include "llvm/ADT/SetVector.h"
332b0c8546SMaheshRavishankar #include "llvm/Support/CommandLine.h"
342b0c8546SMaheshRavishankar #include "llvm/Support/Debug.h"
352b0c8546SMaheshRavishankar 
3667d0d7acSMichele Scuttari namespace mlir {
371e98d488SQuinn Dawkins #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
3867d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
3967d0d7acSMichele Scuttari } // namespace mlir
4067d0d7acSMichele Scuttari 
412b0c8546SMaheshRavishankar #define DEBUG_TYPE "linalg-drop-unit-dims"
422b0c8546SMaheshRavishankar 
432b0c8546SMaheshRavishankar using namespace mlir;
442b0c8546SMaheshRavishankar using namespace mlir::linalg;
452b0c8546SMaheshRavishankar 
46e07149c9SMatthias Springer namespace {
47d2b070d3SMatthias Springer /// Pattern to move init operands to ins when all the loops are parallel and
489b16d9d2SHanhan Wang /// blockArgument corresponding to init is used in the region. This is a fix-up
499b16d9d2SHanhan Wang /// when unit reduction dimensions are all folded away. In this context, it
509b16d9d2SHanhan Wang /// becomes a elementwise generic op. E.g., it converts
519b16d9d2SHanhan Wang ///
529b16d9d2SHanhan Wang ///  %0 = tensor.empty() : tensor<1x1xf32>
539b16d9d2SHanhan Wang ///  %1 = linalg.fill
549b16d9d2SHanhan Wang ///    ins(%cst : f32)
559b16d9d2SHanhan Wang ///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
569b16d9d2SHanhan Wang ///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
579b16d9d2SHanhan Wang ///                                        affine_map<(d0) -> (0, d0)>],
589b16d9d2SHanhan Wang ///                       iterator_types = ["parallel"]}
599b16d9d2SHanhan Wang ///    ins(%arg0 : tensor<1x?x1x1xf32>)
609b16d9d2SHanhan Wang ///    outs(%1 : tensor<1x1xf32>) {
619b16d9d2SHanhan Wang ///  ^bb0(%in: f32, %out: f32):
629b16d9d2SHanhan Wang ///    %3 = arith.addf %in, %out : f32
639b16d9d2SHanhan Wang ///    linalg.yield %3 : f32
649b16d9d2SHanhan Wang ///  } -> tensor<1x1xf32>
659b16d9d2SHanhan Wang ///
669b16d9d2SHanhan Wang ///  into
679b16d9d2SHanhan Wang ///
689b16d9d2SHanhan Wang ///  %0 = tensor.empty() : tensor<1x1xf32>
699b16d9d2SHanhan Wang ///  %1 = linalg.fill
709b16d9d2SHanhan Wang ///    ins(%cst : f32)
719b16d9d2SHanhan Wang ///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
729b16d9d2SHanhan Wang ///  %2 = tensor.empty() : tensor<1x1xf32>
739b16d9d2SHanhan Wang ///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
749b16d9d2SHanhan Wang ///                                        affine_map<(d0) -> (0, d0)>,
759b16d9d2SHanhan Wang ///                                        affine_map<(d0) -> (0, d0)>],
769b16d9d2SHanhan Wang ///                       iterator_types = ["parallel"]}
779b16d9d2SHanhan Wang ///   ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
789b16d9d2SHanhan Wang ///   outs(%2 : tensor<1x1xf32>) {
799b16d9d2SHanhan Wang ///  ^bb0(%in: f32, %in_0: f32, %out: f32):
809b16d9d2SHanhan Wang ///    %4 = arith.addf %in, %in_0 : f32
819b16d9d2SHanhan Wang ///    linalg.yield %4 : f32
829b16d9d2SHanhan Wang ///  } -> tensor<1x1xf32>
83d2b070d3SMatthias Springer struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
849b16d9d2SHanhan Wang   using OpRewritePattern<GenericOp>::OpRewritePattern;
859b16d9d2SHanhan Wang   LogicalResult matchAndRewrite(GenericOp genericOp,
869b16d9d2SHanhan Wang                                 PatternRewriter &rewriter) const override {
870a8e3dd4SMatthias Springer     if (!genericOp.hasPureTensorSemantics())
889b16d9d2SHanhan Wang       return failure();
899b16d9d2SHanhan Wang     if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
909b16d9d2SHanhan Wang       return failure();
919b16d9d2SHanhan Wang 
920b2197b0SMatthias Springer     auto outputOperands = genericOp.getDpsInitsMutable();
939b16d9d2SHanhan Wang     SetVector<OpOperand *> candidates;
940b2197b0SMatthias Springer     for (OpOperand &op : outputOperands) {
950b2197b0SMatthias Springer       if (genericOp.getMatchingBlockArgument(&op).use_empty())
969b16d9d2SHanhan Wang         continue;
970b2197b0SMatthias Springer       candidates.insert(&op);
989b16d9d2SHanhan Wang     }
999b16d9d2SHanhan Wang 
1009b16d9d2SHanhan Wang     if (candidates.empty())
1019b16d9d2SHanhan Wang       return failure();
1029b16d9d2SHanhan Wang 
1039b16d9d2SHanhan Wang     // Compute the modified indexing maps.
1049b16d9d2SHanhan Wang     int64_t origNumInput = genericOp.getNumDpsInputs();
1050b2197b0SMatthias Springer     SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
1069b16d9d2SHanhan Wang     SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
1079b16d9d2SHanhan Wang     SmallVector<AffineMap> newIndexingMaps;
1089b16d9d2SHanhan Wang     newIndexingMaps.append(indexingMaps.begin(),
1099b16d9d2SHanhan Wang                            std::next(indexingMaps.begin(), origNumInput));
1109b16d9d2SHanhan Wang     for (OpOperand *op : candidates) {
1119b16d9d2SHanhan Wang       newInputOperands.push_back(op->get());
1129b16d9d2SHanhan Wang       newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
1139b16d9d2SHanhan Wang     }
1149b16d9d2SHanhan Wang     newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
1159b16d9d2SHanhan Wang                            indexingMaps.end());
1169b16d9d2SHanhan Wang 
1179b16d9d2SHanhan Wang     Location loc = genericOp.getLoc();
1180b2197b0SMatthias Springer     SmallVector<Value> newOutputOperands =
1190b2197b0SMatthias Springer         llvm::to_vector(genericOp.getDpsInits());
1209b16d9d2SHanhan Wang     for (OpOperand *op : candidates) {
1219b16d9d2SHanhan Wang       OpBuilder::InsertionGuard guard(rewriter);
1229b16d9d2SHanhan Wang       rewriter.setInsertionPointAfterValue(op->get());
1235550c821STres Popp       auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
1249b16d9d2SHanhan Wang       auto empty = rewriter.create<tensor::EmptyOp>(
1256596b0ddSMatthias Springer           loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
1269b16d9d2SHanhan Wang 
1270b2197b0SMatthias Springer       unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
1289b16d9d2SHanhan Wang       newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
1299b16d9d2SHanhan Wang     }
1309b16d9d2SHanhan Wang 
1319b16d9d2SHanhan Wang     auto newOp = rewriter.create<GenericOp>(
1329b16d9d2SHanhan Wang         loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
1339b16d9d2SHanhan Wang         newIndexingMaps, genericOp.getIteratorTypesArray(),
1349b16d9d2SHanhan Wang         /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1359b16d9d2SHanhan Wang 
1369b16d9d2SHanhan Wang     OpBuilder::InsertionGuard guard(rewriter);
13791d5653eSMatthias Springer     Region &region = newOp.getRegion();
13891d5653eSMatthias Springer     Block *block = rewriter.createBlock(&region);
13991d5653eSMatthias Springer     IRMapping mapper;
1409b16d9d2SHanhan Wang     for (auto bbarg : genericOp.getRegionInputArgs())
1419b16d9d2SHanhan Wang       mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
1429b16d9d2SHanhan Wang 
1439b16d9d2SHanhan Wang     for (OpOperand *op : candidates) {
1449b16d9d2SHanhan Wang       BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
1459b16d9d2SHanhan Wang       mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
1469b16d9d2SHanhan Wang     }
1479b16d9d2SHanhan Wang 
1480b2197b0SMatthias Springer     for (OpOperand &op : outputOperands) {
1490b2197b0SMatthias Springer       BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
1500b2197b0SMatthias Springer       if (candidates.count(&op))
1519b16d9d2SHanhan Wang         block->addArgument(bbarg.getType(), loc);
1529b16d9d2SHanhan Wang       else
1539b16d9d2SHanhan Wang         mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
1549b16d9d2SHanhan Wang     }
1559b16d9d2SHanhan Wang 
1569b16d9d2SHanhan Wang     for (auto &op : genericOp.getBody()->getOperations()) {
1579b16d9d2SHanhan Wang       rewriter.clone(op, mapper);
1589b16d9d2SHanhan Wang     }
1599b16d9d2SHanhan Wang     rewriter.replaceOp(genericOp, newOp.getResults());
1609b16d9d2SHanhan Wang 
1619b16d9d2SHanhan Wang     return success();
1629b16d9d2SHanhan Wang   }
1639b16d9d2SHanhan Wang };
1642b0c8546SMaheshRavishankar } // namespace
1652b0c8546SMaheshRavishankar 
16667399932SMahesh Ravishankar //===---------------------------------------------------------------------===//
16767399932SMahesh Ravishankar // Drop loops that are unit-extents within Linalg operations.
16867399932SMahesh Ravishankar //===---------------------------------------------------------------------===//
1692b0c8546SMaheshRavishankar 
17067399932SMahesh Ravishankar /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
17167399932SMahesh Ravishankar /// broadcasting. For example,
17267399932SMahesh Ravishankar ///
17367399932SMahesh Ravishankar /// ```mlir
17467399932SMahesh Ravishankar /// #accesses = [
17567399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (0, d1)>,
17667399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (d0, 0)>,
17767399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (d0, d1)>
17867399932SMahesh Ravishankar /// ]
17967399932SMahesh Ravishankar ///
18067399932SMahesh Ravishankar /// #trait = {
18167399932SMahesh Ravishankar ///   indexing_maps = #accesses,
18267399932SMahesh Ravishankar ///   iterator_types = ["parallel", "parallel"],
18367399932SMahesh Ravishankar ///   library_call = "some_external_fn"
18467399932SMahesh Ravishankar /// }
18567399932SMahesh Ravishankar ///
18667399932SMahesh Ravishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
18767399932SMahesh Ravishankar /// tensor<5x5xf32>
18867399932SMahesh Ravishankar /// {
18967399932SMahesh Ravishankar ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
19067399932SMahesh Ravishankar ///        tensor<5xf32> into tensor<1x5xf32>
19167399932SMahesh Ravishankar ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
19267399932SMahesh Ravishankar ///        tensor<5xf32> into tensor<5x1xf32>
19367399932SMahesh Ravishankar ///   %2 = linalg.generic #trait %0, %1 {
19467399932SMahesh Ravishankar ///        ^bb0(%arg2: f32, %arg3: f32):
19567399932SMahesh Ravishankar ///          %3 = arith.addf %arg2, %arg3 : f32
19667399932SMahesh Ravishankar ///          linalg.yield %3 : f32
19767399932SMahesh Ravishankar ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
19867399932SMahesh Ravishankar ///   return %2 : tensor<5x5xf32>
/// }
/// ```
20067399932SMahesh Ravishankar ///
20167399932SMahesh Ravishankar /// would canonicalize to
20267399932SMahesh Ravishankar ///
20367399932SMahesh Ravishankar /// ```mlir
20467399932SMahesh Ravishankar /// #accesses = [
20567399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (d1)>,
20667399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (d0)>,
20767399932SMahesh Ravishankar ///   affine_map<(d0, d1) -> (d0, d1)>
20867399932SMahesh Ravishankar /// ]
20967399932SMahesh Ravishankar ///
21067399932SMahesh Ravishankar /// #trait = {
21167399932SMahesh Ravishankar ///   indexing_maps = #accesses,
21267399932SMahesh Ravishankar ///   iterator_types = ["parallel", "parallel"],
21367399932SMahesh Ravishankar ///   library_call = "some_external_fn"
21467399932SMahesh Ravishankar /// }
21567399932SMahesh Ravishankar ///
21667399932SMahesh Ravishankar /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
21767399932SMahesh Ravishankar /// tensor<5x5xf32>
21867399932SMahesh Ravishankar /// {
21967399932SMahesh Ravishankar ///   %0 = linalg.generic #trait %arg0, %arg1 {
22067399932SMahesh Ravishankar ///        ^bb0(%arg2: f32, %arg3: f32):
22167399932SMahesh Ravishankar ///          %3 = arith.addf %arg2, %arg3 : f32
22267399932SMahesh Ravishankar ///          linalg.yield %3 : f32
22367399932SMahesh Ravishankar ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
22467399932SMahesh Ravishankar ///   return %0 : tensor<5x5xf32>
/// }
/// ```
2262b0c8546SMaheshRavishankar 
22767399932SMahesh Ravishankar /// Update the index accesses of linalg operations having index semantics.
22867399932SMahesh Ravishankar static void
22967399932SMahesh Ravishankar replaceUnitDimIndexOps(GenericOp genericOp,
23067399932SMahesh Ravishankar                        const llvm::SmallDenseSet<unsigned> &unitDims,
23167399932SMahesh Ravishankar                        RewriterBase &rewriter) {
23267399932SMahesh Ravishankar   for (IndexOp indexOp :
23367399932SMahesh Ravishankar        llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
23467399932SMahesh Ravishankar     OpBuilder::InsertionGuard guard(rewriter);
23567399932SMahesh Ravishankar     rewriter.setInsertionPoint(indexOp);
23667399932SMahesh Ravishankar     if (unitDims.count(indexOp.getDim()) != 0) {
23767399932SMahesh Ravishankar       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
23867399932SMahesh Ravishankar     } else {
23967399932SMahesh Ravishankar       // Update the dimension of the index operation if needed.
24067399932SMahesh Ravishankar       unsigned droppedDims = llvm::count_if(
24167399932SMahesh Ravishankar           unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
24267399932SMahesh Ravishankar       if (droppedDims != 0)
24367399932SMahesh Ravishankar         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
24467399932SMahesh Ravishankar                                              indexOp.getDim() - droppedDims);
24567399932SMahesh Ravishankar     }
24667399932SMahesh Ravishankar   }
24744485fcdSTres Popp }
24844485fcdSTres Popp 
24967399932SMahesh Ravishankar /// Expand the given `value` so that the type matches the type of `origDest`.
25067399932SMahesh Ravishankar /// The `reassociation` is used when `rankReductionStrategy` is set to
25167399932SMahesh Ravishankar /// `RankReductionStrategy::ReassociativeReshape`.
25267399932SMahesh Ravishankar static Value
25367399932SMahesh Ravishankar expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
25467399932SMahesh Ravishankar             ArrayRef<ReassociationIndices> reassociation,
25567399932SMahesh Ravishankar             ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
256e07149c9SMatthias Springer   // There are no results for memref outputs.
25767399932SMahesh Ravishankar   auto origResultType = cast<RankedTensorType>(origDest.getType());
25867399932SMahesh Ravishankar   if (rankReductionStrategy ==
25967399932SMahesh Ravishankar       ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
260e07149c9SMatthias Springer     unsigned rank = origResultType.getRank();
261e07149c9SMatthias Springer     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
262e07149c9SMatthias Springer     SmallVector<OpFoldResult> sizes =
26367399932SMahesh Ravishankar         tensor::getMixedSizes(rewriter, loc, origDest);
264e07149c9SMatthias Springer     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
265e07149c9SMatthias Springer     return rewriter.createOrFold<tensor::InsertSliceOp>(
26667399932SMahesh Ravishankar         loc, result, origDest, offsets, sizes, strides);
2676c7be417STres Popp   }
2686c7be417STres Popp 
269e07149c9SMatthias Springer   assert(rankReductionStrategy ==
27067399932SMahesh Ravishankar              ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
271e07149c9SMatthias Springer          "unknown rank reduction strategy");
27297069a86SGaurav Shukla   return rewriter
27397069a86SGaurav Shukla       .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
27497069a86SGaurav Shukla       .getResult();
2756c7be417STres Popp }
276e07149c9SMatthias Springer 
27767399932SMahesh Ravishankar /// Collapse the given `value` so that the type matches the type of
27867399932SMahesh Ravishankar /// `origOutput`. The `reassociation` is used when `rankReductionStrategy` is
27967399932SMahesh Ravishankar /// set to `RankReductionStrategy::ReassociativeReshape`.
28067399932SMahesh Ravishankar static Value collapseValue(
28167399932SMahesh Ravishankar     RewriterBase &rewriter, Location loc, Value operand,
28267399932SMahesh Ravishankar     ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
28367399932SMahesh Ravishankar     ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
2845550c821STres Popp   if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
28567399932SMahesh Ravishankar     if (rankReductionStrategy ==
28667399932SMahesh Ravishankar         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
287e07149c9SMatthias Springer       FailureOr<Value> rankReducingExtract =
288e07149c9SMatthias Springer           memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
289e07149c9SMatthias Springer                                                 targetShape);
290e07149c9SMatthias Springer       assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
291e07149c9SMatthias Springer       return *rankReducingExtract;
2926c7be417STres Popp     }
293e07149c9SMatthias Springer 
29467399932SMahesh Ravishankar     assert(
29567399932SMahesh Ravishankar         rankReductionStrategy ==
29667399932SMahesh Ravishankar             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
297e07149c9SMatthias Springer         "unknown rank reduction strategy");
298e07149c9SMatthias Springer     MemRefLayoutAttrInterface layout;
29967399932SMahesh Ravishankar     auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
30067399932SMahesh Ravishankar                                       layout, memrefType.getMemorySpace());
301e07149c9SMatthias Springer     return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
302e07149c9SMatthias Springer                                                     reassociation);
303e07149c9SMatthias Springer   }
3045550c821STres Popp   if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
30567399932SMahesh Ravishankar     if (rankReductionStrategy ==
30667399932SMahesh Ravishankar         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
307e07149c9SMatthias Springer       FailureOr<Value> rankReducingExtract =
308e07149c9SMatthias Springer           tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
309e07149c9SMatthias Springer                                                      targetShape);
310e07149c9SMatthias Springer       assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
311e07149c9SMatthias Springer       return *rankReducingExtract;
312e07149c9SMatthias Springer     }
313e07149c9SMatthias Springer 
31467399932SMahesh Ravishankar     assert(
31567399932SMahesh Ravishankar         rankReductionStrategy ==
31667399932SMahesh Ravishankar             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
317e07149c9SMatthias Springer         "unknown rank reduction strategy");
318e07149c9SMatthias Springer     auto targetType =
319e07149c9SMatthias Springer         RankedTensorType::get(targetShape, tensorType.getElementType());
320e07149c9SMatthias Springer     return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
321e07149c9SMatthias Springer                                                     reassociation);
322e07149c9SMatthias Springer   }
323e07149c9SMatthias Springer   llvm_unreachable("unsupported operand type");
324e07149c9SMatthias Springer }
3256c7be417STres Popp 
32667399932SMahesh Ravishankar /// Compute the modified metadata for an operands of operation
32767399932SMahesh Ravishankar /// whose unit dims are being dropped. Return the new indexing map
32867399932SMahesh Ravishankar /// to use, the shape of the operand in the replacement op
32967399932SMahesh Ravishankar /// and the `reassocation` to use to go from original operand shape
33067399932SMahesh Ravishankar /// to modified operand shape.
33167399932SMahesh Ravishankar struct UnitExtentReplacementInfo {
33267399932SMahesh Ravishankar   AffineMap indexMap;
33367399932SMahesh Ravishankar   SmallVector<ReassociationIndices> reassociation;
33467399932SMahesh Ravishankar   SmallVector<int64_t> targetShape;
33567399932SMahesh Ravishankar };
33667399932SMahesh Ravishankar static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
33767399932SMahesh Ravishankar     MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
33867399932SMahesh Ravishankar     llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
33967399932SMahesh Ravishankar     ArrayRef<AffineExpr> dimReplacements) {
34067399932SMahesh Ravishankar   UnitExtentReplacementInfo info;
34167399932SMahesh Ravishankar   ReassociationIndices reassociationGroup;
34267399932SMahesh Ravishankar   SmallVector<AffineExpr> newIndexExprs;
34367399932SMahesh Ravishankar   AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
34467399932SMahesh Ravishankar   ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
34567399932SMahesh Ravishankar   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
3462b0c8546SMaheshRavishankar 
34767399932SMahesh Ravishankar   auto isUnitDim = [&](unsigned dim) {
3481609f1c2Slong.chen     if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
34967399932SMahesh Ravishankar       unsigned oldPosition = dimExpr.getPosition();
35044861c7aSSayan Saha       return !oldDimsToNewDimsMap.count(oldPosition) &&
35144861c7aSSayan Saha              (operandShape[dim] == 1);
35267399932SMahesh Ravishankar     }
35367399932SMahesh Ravishankar     // Handle the other case where the shape is 1, and is accessed using a
35467399932SMahesh Ravishankar     // constant 0.
35567399932SMahesh Ravishankar     if (operandShape[dim] == 1) {
3561609f1c2Slong.chen       auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
35767399932SMahesh Ravishankar       return constAffineExpr && constAffineExpr.getValue() == 0;
35867399932SMahesh Ravishankar     }
35967399932SMahesh Ravishankar     return false;
36067399932SMahesh Ravishankar   };
36167399932SMahesh Ravishankar 
3622efe88b5SMahesh Ravishankar   unsigned dim = 0;
36367399932SMahesh Ravishankar   while (dim < operandShape.size() && isUnitDim(dim))
36467399932SMahesh Ravishankar     reassociationGroup.push_back(dim++);
36567399932SMahesh Ravishankar   while (dim < operandShape.size()) {
36667399932SMahesh Ravishankar     assert(!isUnitDim(dim) && "expected non unit-extent");
36767399932SMahesh Ravishankar     reassociationGroup.push_back(dim);
36867399932SMahesh Ravishankar     AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
36967399932SMahesh Ravishankar     newIndexExprs.push_back(newExpr);
37067399932SMahesh Ravishankar     info.targetShape.push_back(operandShape[dim]);
37167399932SMahesh Ravishankar     ++dim;
37267399932SMahesh Ravishankar     // Fold all following dimensions that are unit-extent.
37367399932SMahesh Ravishankar     while (dim < operandShape.size() && isUnitDim(dim)) {
37467399932SMahesh Ravishankar       reassociationGroup.push_back(dim++);
37567399932SMahesh Ravishankar     }
37667399932SMahesh Ravishankar     info.reassociation.push_back(reassociationGroup);
37767399932SMahesh Ravishankar     reassociationGroup.clear();
37867399932SMahesh Ravishankar   }
37967399932SMahesh Ravishankar   info.indexMap =
38067399932SMahesh Ravishankar       AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
38167399932SMahesh Ravishankar                      newIndexExprs, context);
38267399932SMahesh Ravishankar   return info;
38367399932SMahesh Ravishankar }
38467399932SMahesh Ravishankar 
3854dbaef6dSMaheshRavishankar FailureOr<DropUnitDimsResult>
3864dbaef6dSMaheshRavishankar linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
38767399932SMahesh Ravishankar                      const ControlDropUnitDims &options) {
38867399932SMahesh Ravishankar   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
38967399932SMahesh Ravishankar   if (indexingMaps.empty())
39067399932SMahesh Ravishankar     return failure();
39167399932SMahesh Ravishankar 
39267399932SMahesh Ravishankar   // 1. Check if any of the iteration dimensions are unit-trip count. They will
39367399932SMahesh Ravishankar   //    end up being unit-trip count if they are used to index into a unit-dim
39467399932SMahesh Ravishankar   //    tensor/memref.
39506514c55SIan Wood   AffineMap invertedMap =
39606514c55SIan Wood       inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
39767399932SMahesh Ravishankar   if (!invertedMap) {
39867399932SMahesh Ravishankar     return rewriter.notifyMatchFailure(genericOp,
39967399932SMahesh Ravishankar                                        "invalid indexing maps for operation");
40067399932SMahesh Ravishankar   }
40167399932SMahesh Ravishankar   SmallVector<int64_t> dims = genericOp.getStaticShape();
40267399932SMahesh Ravishankar 
40367399932SMahesh Ravishankar   // 1a. Get the allowed list of dimensions to drop from the `options`.
40467399932SMahesh Ravishankar   SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
40567399932SMahesh Ravishankar   if (allowedUnitDims.empty()) {
40667399932SMahesh Ravishankar     return rewriter.notifyMatchFailure(
40767399932SMahesh Ravishankar         genericOp, "control function returns no allowed unit dims to prune");
40867399932SMahesh Ravishankar   }
40967399932SMahesh Ravishankar   llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
41067399932SMahesh Ravishankar                                                allowedUnitDims.end());
41167399932SMahesh Ravishankar   llvm::SmallDenseSet<unsigned> unitDims;
41267399932SMahesh Ravishankar   for (const auto &expr : enumerate(invertedMap.getResults())) {
4131609f1c2Slong.chen     if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
41467399932SMahesh Ravishankar       if (dims[dimExpr.getPosition()] == 1 &&
41567399932SMahesh Ravishankar           unitDimsFilter.count(expr.index()))
41667399932SMahesh Ravishankar         unitDims.insert(expr.index());
41767399932SMahesh Ravishankar     }
41867399932SMahesh Ravishankar   }
41967399932SMahesh Ravishankar 
42067399932SMahesh Ravishankar   // 2. Compute the iterator types of the modified op by dropping the one-trip
42167399932SMahesh Ravishankar   //    count loops.
42267399932SMahesh Ravishankar   SmallVector<utils::IteratorType> newIteratorTypes;
42367399932SMahesh Ravishankar   llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
42467399932SMahesh Ravishankar   SmallVector<AffineExpr> dimReplacements;
42567399932SMahesh Ravishankar   unsigned newDims = 0;
42667399932SMahesh Ravishankar   for (auto [index, attr] :
42767399932SMahesh Ravishankar        llvm::enumerate(genericOp.getIteratorTypesArray())) {
42867399932SMahesh Ravishankar     if (unitDims.count(index)) {
42967399932SMahesh Ravishankar       dimReplacements.push_back(
43067399932SMahesh Ravishankar           getAffineConstantExpr(0, rewriter.getContext()));
43167399932SMahesh Ravishankar     } else {
43267399932SMahesh Ravishankar       newIteratorTypes.push_back(attr);
43367399932SMahesh Ravishankar       oldDimToNewDimMap[index] = newDims;
43467399932SMahesh Ravishankar       dimReplacements.push_back(
43567399932SMahesh Ravishankar           getAffineDimExpr(newDims, rewriter.getContext()));
43667399932SMahesh Ravishankar       newDims++;
43767399932SMahesh Ravishankar     }
43867399932SMahesh Ravishankar   }
43967399932SMahesh Ravishankar 
44067399932SMahesh Ravishankar   // 3. For each of the operands, find the
44167399932SMahesh Ravishankar   //    - modified affine map to use.
44267399932SMahesh Ravishankar   //    - shape of the operands after the unit-dims are dropped.
44367399932SMahesh Ravishankar   //    - the reassociation indices used to convert from the original
44467399932SMahesh Ravishankar   //      operand type to modified operand (needed only when using reshapes
44567399932SMahesh Ravishankar   //      for rank reduction strategy)
44667399932SMahesh Ravishankar   // Note that the indexing maps might need changing even if there are no
44767399932SMahesh Ravishankar   // unit dimensions that are dropped to handle cases where `0` is used to
44867399932SMahesh Ravishankar   // access a unit-extent tensor. Consider moving this out of this specific
44967399932SMahesh Ravishankar   // transformation as a stand-alone transformation. Kept here right now due
45067399932SMahesh Ravishankar   // to legacy.
451f6b4e081STobias Gysi   SmallVector<AffineMap> newIndexingMaps;
452e07149c9SMatthias Springer   SmallVector<SmallVector<ReassociationIndices>> reassociations;
453e07149c9SMatthias Springer   SmallVector<SmallVector<int64_t>> targetShapes;
454e07149c9SMatthias Springer   SmallVector<bool> collapsed;
45567399932SMahesh Ravishankar   auto hasCollapsibleType = [](OpOperand &operand) {
45667399932SMahesh Ravishankar     Type operandType = operand.get().getType();
45767399932SMahesh Ravishankar     if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
45867399932SMahesh Ravishankar       return memrefOperandType.getLayout().isIdentity();
459f19f2139SMehdi Amini     }
460f19f2139SMehdi Amini     if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
46167399932SMahesh Ravishankar       return tensorOperandType.getEncoding() == nullptr;
46244485fcdSTres Popp     }
46367399932SMahesh Ravishankar     return false;
46467399932SMahesh Ravishankar   };
46567399932SMahesh Ravishankar   for (OpOperand &opOperand : genericOp->getOpOperands()) {
46667399932SMahesh Ravishankar     auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
46767399932SMahesh Ravishankar     ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
46867399932SMahesh Ravishankar     if (!hasCollapsibleType(opOperand)) {
46967399932SMahesh Ravishankar       AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
47067399932SMahesh Ravishankar           dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
47167399932SMahesh Ravishankar       newIndexingMaps.push_back(newIndexingMap);
47267399932SMahesh Ravishankar       targetShapes.push_back(llvm::to_vector(shape));
47367399932SMahesh Ravishankar       collapsed.push_back(false);
47467399932SMahesh Ravishankar       reassociations.push_back({});
47567399932SMahesh Ravishankar       continue;
47667399932SMahesh Ravishankar     }
47767399932SMahesh Ravishankar     auto replacementInfo = dropUnitExtentFromOperandMetadata(
47867399932SMahesh Ravishankar         rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
47967399932SMahesh Ravishankar         dimReplacements);
48067399932SMahesh Ravishankar     reassociations.push_back(replacementInfo.reassociation);
48167399932SMahesh Ravishankar     newIndexingMaps.push_back(replacementInfo.indexMap);
48267399932SMahesh Ravishankar     targetShapes.push_back(replacementInfo.targetShape);
48367399932SMahesh Ravishankar     collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
48467399932SMahesh Ravishankar                           indexingMap.getNumResults()));
4852b0c8546SMaheshRavishankar   }
4862b0c8546SMaheshRavishankar 
487e07149c9SMatthias Springer   // Abort if the indexing maps of the result operation are not invertible
488e07149c9SMatthias Springer   // (i.e. not legal) or if no dimension was reduced.
48967399932SMahesh Ravishankar   if (newIndexingMaps == indexingMaps ||
49006514c55SIan Wood       !inversePermutation(
49106514c55SIan Wood           concatAffineMaps(newIndexingMaps, rewriter.getContext())))
4922b0c8546SMaheshRavishankar     return failure();
4932b0c8546SMaheshRavishankar 
49467399932SMahesh Ravishankar   Location loc = genericOp.getLoc();
49567399932SMahesh Ravishankar   // 4. For each of the operands, collapse the operand to convert
49667399932SMahesh Ravishankar   //    from original shape to shape in the modified operation if needed,
49767399932SMahesh Ravishankar   //    either through use of reshapes or rank-reducing slices as
49867399932SMahesh Ravishankar   //    specified in `options`.
499e07149c9SMatthias Springer   SmallVector<Value> newOperands;
500e07149c9SMatthias Springer   for (OpOperand &opOperand : genericOp->getOpOperands()) {
501e07149c9SMatthias Springer     int64_t idx = opOperand.getOperandNumber();
502e07149c9SMatthias Springer     if (!collapsed[idx]) {
503e07149c9SMatthias Springer       newOperands.push_back(opOperand.get());
504e07149c9SMatthias Springer       continue;
5052b0c8546SMaheshRavishankar     }
50667399932SMahesh Ravishankar     newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
50767399932SMahesh Ravishankar                                         targetShapes[idx], reassociations[idx],
50867399932SMahesh Ravishankar                                         options.rankReductionStrategy));
509e07149c9SMatthias Springer   }
5102b0c8546SMaheshRavishankar 
51167399932SMahesh Ravishankar   // 5. Create the `linalg.generic` operation with the new operands,
51267399932SMahesh Ravishankar   //    indexing maps, iterator types and result types.
513e07149c9SMatthias Springer   ArrayRef<Value> newInputs =
514e07149c9SMatthias Springer       ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
515e07149c9SMatthias Springer   ArrayRef<Value> newOutputs =
516e07149c9SMatthias Springer       ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
517e07149c9SMatthias Springer   SmallVector<Type> resultTypes;
518f358c372STobias Gysi   resultTypes.reserve(genericOp.getNumResults());
519f358c372STobias Gysi   for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520e07149c9SMatthias Springer     resultTypes.push_back(newOutputs[i].getType());
52167399932SMahesh Ravishankar   GenericOp replacementOp =
52267399932SMahesh Ravishankar       rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
52367399932SMahesh Ravishankar                                  newIndexingMaps, newIteratorTypes);
52467399932SMahesh Ravishankar   rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
525d3b3f765SJacques Pienaar                               replacementOp.getRegion().begin());
52667399932SMahesh Ravishankar   // 5a. Replace `linalg.index` operations that refer to the dropped unit
52767399932SMahesh Ravishankar   //     dimensions.
52867399932SMahesh Ravishankar   replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
5292b0c8546SMaheshRavishankar 
53067399932SMahesh Ravishankar   // 6. If any result type changes, insert a reshape/slice to convert from the
531f3b4e47bSLongsheng Mou   //    original type to the new type.
532e07149c9SMatthias Springer   SmallVector<Value> resultReplacements;
53367399932SMahesh Ravishankar   for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
53467399932SMahesh Ravishankar     unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
53567399932SMahesh Ravishankar     Value origDest = genericOp.getDpsInitOperand(index)->get();
53667399932SMahesh Ravishankar     if (!collapsed[opOperandIndex]) {
53767399932SMahesh Ravishankar       resultReplacements.push_back(result);
538e07149c9SMatthias Springer       continue;
5392b0c8546SMaheshRavishankar     }
54097069a86SGaurav Shukla     Value expandedValue = expandValue(rewriter, loc, result, origDest,
54167399932SMahesh Ravishankar                                       reassociations[opOperandIndex],
54297069a86SGaurav Shukla                                       options.rankReductionStrategy);
54397069a86SGaurav Shukla     resultReplacements.push_back(expandedValue);
544e07149c9SMatthias Springer   }
545e07149c9SMatthias Springer 
5464dbaef6dSMaheshRavishankar   return DropUnitDimsResult{replacementOp, resultReplacements};
5472b0c8546SMaheshRavishankar }
548e07149c9SMatthias Springer 
54967399932SMahesh Ravishankar namespace {
55067399932SMahesh Ravishankar struct DropUnitDims : public OpRewritePattern<GenericOp> {
55167399932SMahesh Ravishankar   DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
55267399932SMahesh Ravishankar                PatternBenefit benefit = 1)
55367399932SMahesh Ravishankar       : OpRewritePattern(context, benefit), options(std::move(options)) {}
55467399932SMahesh Ravishankar 
55567399932SMahesh Ravishankar   LogicalResult matchAndRewrite(GenericOp genericOp,
55667399932SMahesh Ravishankar                                 PatternRewriter &rewriter) const override {
5574dbaef6dSMaheshRavishankar     FailureOr<DropUnitDimsResult> result =
5584dbaef6dSMaheshRavishankar         dropUnitDims(rewriter, genericOp, options);
5594dbaef6dSMaheshRavishankar     if (failed(result)) {
5604dbaef6dSMaheshRavishankar       return failure();
5614dbaef6dSMaheshRavishankar     }
5624dbaef6dSMaheshRavishankar     rewriter.replaceOp(genericOp, result->replacements);
5634dbaef6dSMaheshRavishankar     return success();
56467399932SMahesh Ravishankar   }
56567399932SMahesh Ravishankar 
566e07149c9SMatthias Springer private:
56767399932SMahesh Ravishankar   ControlDropUnitDims options;
5682b0c8546SMaheshRavishankar };
569fd15e2b8SMaheshRavishankar } // namespace
570f0a2fe7fSMaheshRavishankar 
57160e562d1SQuinn Dawkins //===---------------------------------------------------------------------===//
// Drop unit-extent dimensions within tensor operations.
57360e562d1SQuinn Dawkins //===---------------------------------------------------------------------===//
57460e562d1SQuinn Dawkins 
57560e562d1SQuinn Dawkins namespace {
57660e562d1SQuinn Dawkins struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
57760e562d1SQuinn Dawkins   DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
57860e562d1SQuinn Dawkins                   PatternBenefit benefit = 1)
57960e562d1SQuinn Dawkins       : OpRewritePattern(context, benefit), options(std::move(options)) {}
58060e562d1SQuinn Dawkins 
58160e562d1SQuinn Dawkins   LogicalResult matchAndRewrite(tensor::PadOp padOp,
58260e562d1SQuinn Dawkins                                 PatternRewriter &rewriter) const override {
58360e562d1SQuinn Dawkins     // 1a. Get the allowed list of dimensions to drop from the `options`.
58460e562d1SQuinn Dawkins     SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
58560e562d1SQuinn Dawkins     if (allowedUnitDims.empty()) {
58660e562d1SQuinn Dawkins       return rewriter.notifyMatchFailure(
58760e562d1SQuinn Dawkins           padOp, "control function returns no allowed unit dims to prune");
58860e562d1SQuinn Dawkins     }
58960e562d1SQuinn Dawkins 
59060e562d1SQuinn Dawkins     if (padOp.getSourceType().getEncoding()) {
59160e562d1SQuinn Dawkins       return rewriter.notifyMatchFailure(
59260e562d1SQuinn Dawkins           padOp, "cannot collapse dims of tensor with encoding");
59360e562d1SQuinn Dawkins     }
59460e562d1SQuinn Dawkins 
59560e562d1SQuinn Dawkins     // Fail for non-constant padding values. The body of the pad could
59660e562d1SQuinn Dawkins     // depend on the padding indices and/or properties of the padded
59760e562d1SQuinn Dawkins     // tensor so for now we fail.
59860e562d1SQuinn Dawkins     // TODO: Support non-constant padding values.
59960e562d1SQuinn Dawkins     Value paddingVal = padOp.getConstantPaddingValue();
60060e562d1SQuinn Dawkins     if (!paddingVal) {
60160e562d1SQuinn Dawkins       return rewriter.notifyMatchFailure(
60260e562d1SQuinn Dawkins           padOp, "unimplemented: non-constant padding value");
60360e562d1SQuinn Dawkins     }
60460e562d1SQuinn Dawkins 
60560e562d1SQuinn Dawkins     ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
60660e562d1SQuinn Dawkins     int64_t padRank = sourceShape.size();
60760e562d1SQuinn Dawkins 
60860e562d1SQuinn Dawkins     auto isStaticZero = [](OpFoldResult f) {
60960e562d1SQuinn Dawkins       std::optional<int64_t> maybeInt = getConstantIntValue(f);
61060e562d1SQuinn Dawkins       return maybeInt && *maybeInt == 0;
61160e562d1SQuinn Dawkins     };
61260e562d1SQuinn Dawkins 
61360e562d1SQuinn Dawkins     llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
61460e562d1SQuinn Dawkins                                                  allowedUnitDims.end());
61560e562d1SQuinn Dawkins     llvm::SmallDenseSet<unsigned> unitDims;
61660e562d1SQuinn Dawkins     SmallVector<int64_t> newShape;
61760e562d1SQuinn Dawkins     SmallVector<OpFoldResult> newLowPad;
61860e562d1SQuinn Dawkins     SmallVector<OpFoldResult> newHighPad;
61960e562d1SQuinn Dawkins     for (const auto [dim, size, low, high] :
62060e562d1SQuinn Dawkins          zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
62160e562d1SQuinn Dawkins                    padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
62260e562d1SQuinn Dawkins       if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
62360e562d1SQuinn Dawkins           isStaticZero(high)) {
62460e562d1SQuinn Dawkins         unitDims.insert(dim);
62560e562d1SQuinn Dawkins       } else {
62660e562d1SQuinn Dawkins         newShape.push_back(size);
62760e562d1SQuinn Dawkins         newLowPad.push_back(low);
62860e562d1SQuinn Dawkins         newHighPad.push_back(high);
62960e562d1SQuinn Dawkins       }
63060e562d1SQuinn Dawkins     }
63160e562d1SQuinn Dawkins 
63260e562d1SQuinn Dawkins     if (unitDims.empty()) {
63360e562d1SQuinn Dawkins       return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
63460e562d1SQuinn Dawkins     }
63560e562d1SQuinn Dawkins 
63660e562d1SQuinn Dawkins     ReassociationIndices reassociationGroup;
63760e562d1SQuinn Dawkins     SmallVector<ReassociationIndices> reassociationMap;
63860e562d1SQuinn Dawkins     int64_t dim = 0;
63960e562d1SQuinn Dawkins     while (dim < padRank && unitDims.contains(dim))
64060e562d1SQuinn Dawkins       reassociationGroup.push_back(dim++);
64160e562d1SQuinn Dawkins     while (dim < padRank) {
64260e562d1SQuinn Dawkins       assert(!unitDims.contains(dim) && "expected non unit-extent");
64360e562d1SQuinn Dawkins       reassociationGroup.push_back(dim);
64460e562d1SQuinn Dawkins       dim++;
64560e562d1SQuinn Dawkins       // Fold all following dimensions that are unit-extent.
64660e562d1SQuinn Dawkins       while (dim < padRank && unitDims.contains(dim))
64760e562d1SQuinn Dawkins         reassociationGroup.push_back(dim++);
64860e562d1SQuinn Dawkins       reassociationMap.push_back(reassociationGroup);
64960e562d1SQuinn Dawkins       reassociationGroup.clear();
65060e562d1SQuinn Dawkins     }
65160e562d1SQuinn Dawkins 
65260e562d1SQuinn Dawkins     Value collapsedSource =
65360e562d1SQuinn Dawkins         collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
65460e562d1SQuinn Dawkins                       reassociationMap, options.rankReductionStrategy);
65560e562d1SQuinn Dawkins 
65660e562d1SQuinn Dawkins     auto newPadOp = rewriter.create<tensor::PadOp>(
65760e562d1SQuinn Dawkins         padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
65860e562d1SQuinn Dawkins         newHighPad, paddingVal, padOp.getNofold());
65960e562d1SQuinn Dawkins 
66060e562d1SQuinn Dawkins     Value dest = padOp.getResult();
66160e562d1SQuinn Dawkins     if (options.rankReductionStrategy ==
66260e562d1SQuinn Dawkins         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
66360e562d1SQuinn Dawkins       SmallVector<OpFoldResult> expandedSizes;
66460e562d1SQuinn Dawkins       int64_t numUnitDims = 0;
66560e562d1SQuinn Dawkins       for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
66660e562d1SQuinn Dawkins         if (unitDims.contains(dim)) {
66760e562d1SQuinn Dawkins           expandedSizes.push_back(rewriter.getIndexAttr(1));
66860e562d1SQuinn Dawkins           numUnitDims++;
66960e562d1SQuinn Dawkins           continue;
67060e562d1SQuinn Dawkins         }
67160e562d1SQuinn Dawkins         expandedSizes.push_back(tensor::getMixedSize(
67260e562d1SQuinn Dawkins             rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
67360e562d1SQuinn Dawkins       }
67460e562d1SQuinn Dawkins       dest = rewriter.create<tensor::EmptyOp>(
67560e562d1SQuinn Dawkins           padOp.getLoc(), expandedSizes,
67660e562d1SQuinn Dawkins           padOp.getResultType().getElementType());
67760e562d1SQuinn Dawkins     }
67860e562d1SQuinn Dawkins 
67960e562d1SQuinn Dawkins     Value expandedValue =
68060e562d1SQuinn Dawkins         expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
68160e562d1SQuinn Dawkins                     reassociationMap, options.rankReductionStrategy);
68260e562d1SQuinn Dawkins     rewriter.replaceOp(padOp, expandedValue);
68360e562d1SQuinn Dawkins     return success();
68460e562d1SQuinn Dawkins   }
68560e562d1SQuinn Dawkins 
68660e562d1SQuinn Dawkins private:
68760e562d1SQuinn Dawkins   ControlDropUnitDims options;
68860e562d1SQuinn Dawkins };
68960e562d1SQuinn Dawkins } // namespace
69060e562d1SQuinn Dawkins 
691fd15e2b8SMaheshRavishankar namespace {
692060208b4SMatthias Springer /// Convert `extract_slice` operations to rank-reduced versions.
693df5c981bSNicolas Vasilache struct RankReducedExtractSliceOp
694060208b4SMatthias Springer     : public OpRewritePattern<tensor::ExtractSliceOp> {
695060208b4SMatthias Springer   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
696f0a2fe7fSMaheshRavishankar 
697060208b4SMatthias Springer   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
698f0a2fe7fSMaheshRavishankar                                 PatternRewriter &rewriter) const override {
699060208b4SMatthias Springer     RankedTensorType resultType = sliceOp.getType();
70082ab0f7fSQuinn Dawkins     SmallVector<OpFoldResult> targetShape;
70182ab0f7fSQuinn Dawkins     for (auto size : resultType.getShape())
70282ab0f7fSQuinn Dawkins       targetShape.push_back(rewriter.getIndexAttr(size));
70382ab0f7fSQuinn Dawkins     auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
704fd15e2b8SMaheshRavishankar     if (!reassociation ||
705fd15e2b8SMaheshRavishankar         reassociation->size() == static_cast<size_t>(resultType.getRank()))
706f0a2fe7fSMaheshRavishankar       return failure();
70782ab0f7fSQuinn Dawkins 
70882ab0f7fSQuinn Dawkins     SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
70982ab0f7fSQuinn Dawkins     SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
71082ab0f7fSQuinn Dawkins     SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
7115550c821STres Popp     auto rankReducedType = cast<RankedTensorType>(
712741f8f2bSNicolas Vasilache         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
713741f8f2bSNicolas Vasilache             reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
7145550c821STres Popp             strides));
715f0a2fe7fSMaheshRavishankar 
716060208b4SMatthias Springer     Location loc = sliceOp.getLoc();
717060208b4SMatthias Springer     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
71804235d07SJacques Pienaar         loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
719b618880eSAlexander Belyaev     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
720b618880eSAlexander Belyaev         sliceOp, resultType, newSlice, *reassociation);
721f0a2fe7fSMaheshRavishankar     return success();
722f0a2fe7fSMaheshRavishankar   }
723f0a2fe7fSMaheshRavishankar };
724f0a2fe7fSMaheshRavishankar 
725060208b4SMatthias Springer /// Convert `insert_slice` operations to rank-reduced versions.
726df5c981bSNicolas Vasilache /// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
727df5c981bSNicolas Vasilache template <typename InsertOpTy>
728df5c981bSNicolas Vasilache struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
729df5c981bSNicolas Vasilache   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
730fd15e2b8SMaheshRavishankar 
731df5c981bSNicolas Vasilache   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
732fd15e2b8SMaheshRavishankar                                 PatternRewriter &rewriter) const override {
733df5c981bSNicolas Vasilache     RankedTensorType sourceType = insertSliceOp.getSourceType();
73482ab0f7fSQuinn Dawkins     SmallVector<OpFoldResult> targetShape;
73582ab0f7fSQuinn Dawkins     for (auto size : sourceType.getShape())
73682ab0f7fSQuinn Dawkins       targetShape.push_back(rewriter.getIndexAttr(size));
73782ab0f7fSQuinn Dawkins     auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
738fd15e2b8SMaheshRavishankar     if (!reassociation ||
739fd15e2b8SMaheshRavishankar         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
740fd15e2b8SMaheshRavishankar       return failure();
74182ab0f7fSQuinn Dawkins 
742df5c981bSNicolas Vasilache     Location loc = insertSliceOp.getLoc();
743df5c981bSNicolas Vasilache     tensor::CollapseShapeOp reshapedSource;
744df5c981bSNicolas Vasilache     {
745df5c981bSNicolas Vasilache       OpBuilder::InsertionGuard g(rewriter);
74667399932SMahesh Ravishankar       // The only difference between InsertSliceOp and ParallelInsertSliceOp
74767399932SMahesh Ravishankar       // is the insertion point is just before the ParallelCombiningOp in the
748df5c981bSNicolas Vasilache       // parallel case.
749df5c981bSNicolas Vasilache       if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
750df5c981bSNicolas Vasilache         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
751df5c981bSNicolas Vasilache       reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
752df5c981bSNicolas Vasilache           loc, insertSliceOp.getSource(), *reassociation);
753df5c981bSNicolas Vasilache     }
754df5c981bSNicolas Vasilache     rewriter.replaceOpWithNewOp<InsertOpTy>(
755df5c981bSNicolas Vasilache         insertSliceOp, reshapedSource, insertSliceOp.getDest(),
756df5c981bSNicolas Vasilache         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
757df5c981bSNicolas Vasilache         insertSliceOp.getMixedStrides());
758fd15e2b8SMaheshRavishankar     return success();
759fd15e2b8SMaheshRavishankar   }
760fd15e2b8SMaheshRavishankar };
761b62f9f44SMaheshRavishankar } // namespace
762b62f9f44SMaheshRavishankar 
7632b0c8546SMaheshRavishankar /// Patterns that are used to canonicalize the use of unit-extent dims for
7642b0c8546SMaheshRavishankar /// broadcasting.
76567399932SMahesh Ravishankar static void
76667399932SMahesh Ravishankar populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
76767399932SMahesh Ravishankar                                               ControlDropUnitDims &options) {
7683a506b31SChris Lattner   auto *context = patterns.getContext();
76967399932SMahesh Ravishankar   patterns.add<DropUnitDims>(context, options);
77060e562d1SQuinn Dawkins   patterns.add<DropPadUnitDims>(context, options);
771e07149c9SMatthias Springer   // TODO: Patterns unrelated to unit dim folding should be factored out.
77267399932SMahesh Ravishankar   patterns.add<RankReducedExtractSliceOp,
773df5c981bSNicolas Vasilache                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
774d2b070d3SMatthias Springer                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
775d2b070d3SMatthias Springer       context);
776b618880eSAlexander Belyaev   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
777b618880eSAlexander Belyaev   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
77881ca5aa4SMatthias Springer   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
779b618880eSAlexander Belyaev   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
780f6fb0a4fSAlexander Belyaev   tensor::populateFoldTensorEmptyPatterns(patterns);
7814abccd39SMatthias Springer   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
7829b16d9d2SHanhan Wang   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
7832b0c8546SMaheshRavishankar }
7842b0c8546SMaheshRavishankar 
78567399932SMahesh Ravishankar static void
78667399932SMahesh Ravishankar populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
78767399932SMahesh Ravishankar                                             ControlDropUnitDims &options) {
788e07149c9SMatthias Springer   auto *context = patterns.getContext();
78967399932SMahesh Ravishankar   patterns.add<DropUnitDims>(context, options);
79060e562d1SQuinn Dawkins   patterns.add<DropPadUnitDims>(context, options);
7916b76c4eaSMatthias Springer   // TODO: Patterns unrelated to unit dim folding should be factored out.
7926b76c4eaSMatthias Springer   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
7936b76c4eaSMatthias Springer   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
7946b76c4eaSMatthias Springer   tensor::populateFoldTensorEmptyPatterns(patterns);
7954abccd39SMatthias Springer   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
7966b76c4eaSMatthias Springer   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
797e07149c9SMatthias Springer }
798e07149c9SMatthias Springer 
79967399932SMahesh Ravishankar void mlir::linalg::populateFoldUnitExtentDimsPatterns(
80067399932SMahesh Ravishankar     RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
80167399932SMahesh Ravishankar   if (options.rankReductionStrategy ==
80267399932SMahesh Ravishankar       linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
80367399932SMahesh Ravishankar     populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
80467399932SMahesh Ravishankar   } else if (options.rankReductionStrategy ==
80567399932SMahesh Ravishankar              linalg::ControlDropUnitDims::RankReductionStrategy::
80667399932SMahesh Ravishankar                  ReassociativeReshape) {
80767399932SMahesh Ravishankar     populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
80867399932SMahesh Ravishankar   }
80967399932SMahesh Ravishankar }
81067399932SMahesh Ravishankar 
811d2b070d3SMatthias Springer void mlir::linalg::populateMoveInitOperandsToInputPattern(
812d2b070d3SMatthias Springer     RewritePatternSet &patterns) {
813d2b070d3SMatthias Springer   patterns.add<MoveInitOperandsToInput>(patterns.getContext());
814d2b070d3SMatthias Springer }
815d2b070d3SMatthias Springer 
8162b0c8546SMaheshRavishankar namespace {
8172b0c8546SMaheshRavishankar /// Pass that removes unit-extent dims within generic ops.
8182b0c8546SMaheshRavishankar struct LinalgFoldUnitExtentDimsPass
8191e98d488SQuinn Dawkins     : public impl::LinalgFoldUnitExtentDimsPassBase<
8201e98d488SQuinn Dawkins           LinalgFoldUnitExtentDimsPass> {
8211e98d488SQuinn Dawkins   using impl::LinalgFoldUnitExtentDimsPassBase<
8221e98d488SQuinn Dawkins       LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
823c10995a8SStella Laurenzo   void runOnOperation() override {
824c10995a8SStella Laurenzo     Operation *op = getOperation();
825c10995a8SStella Laurenzo     MLIRContext *context = op->getContext();
826dc4e913bSChris Lattner     RewritePatternSet patterns(context);
82767399932SMahesh Ravishankar     ControlDropUnitDims options;
82867399932SMahesh Ravishankar     if (useRankReducingSlices) {
82967399932SMahesh Ravishankar       options.rankReductionStrategy = linalg::ControlDropUnitDims::
83067399932SMahesh Ravishankar           RankReductionStrategy::ExtractInsertSlice;
831e07149c9SMatthias Springer     }
83267399932SMahesh Ravishankar     linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
83367399932SMahesh Ravishankar     populateMoveInitOperandsToInputPattern(patterns);
834*09dfc571SJacques Pienaar     (void)applyPatternsGreedily(op, std::move(patterns));
8352b0c8546SMaheshRavishankar   }
8362b0c8546SMaheshRavishankar };
837431213c9Ssrcarroll 
8382b0c8546SMaheshRavishankar } // namespace
839431213c9Ssrcarroll 
840431213c9Ssrcarroll namespace {
841431213c9Ssrcarroll 
842431213c9Ssrcarroll /// Returns reassociation indices for collapsing/expanding a
843431213c9Ssrcarroll /// tensor of rank `rank` at position `pos`.
844431213c9Ssrcarroll static SmallVector<ReassociationIndices>
845431213c9Ssrcarroll getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
846431213c9Ssrcarroll   SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
847431213c9Ssrcarroll   bool lastDim = pos == rank - 1;
848431213c9Ssrcarroll   if (rank > 2) {
849431213c9Ssrcarroll     for (int64_t i = 0; i < rank - 1; i++) {
850431213c9Ssrcarroll       if (i == pos || (lastDim && i == pos - 1))
851431213c9Ssrcarroll         reassociation[i] = ReassociationIndices{i, i + 1};
852431213c9Ssrcarroll       else if (i < pos)
853431213c9Ssrcarroll         reassociation[i] = ReassociationIndices{i};
854431213c9Ssrcarroll       else
855431213c9Ssrcarroll         reassociation[i] = ReassociationIndices{i + 1};
856431213c9Ssrcarroll     }
857431213c9Ssrcarroll   }
858431213c9Ssrcarroll   return reassociation;
859431213c9Ssrcarroll }
860431213c9Ssrcarroll 
861431213c9Ssrcarroll /// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
862431213c9Ssrcarroll /// If `pos < 0`, then don't collapse.
863431213c9Ssrcarroll static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
864431213c9Ssrcarroll                                     int64_t pos) {
865431213c9Ssrcarroll   if (pos < 0)
866431213c9Ssrcarroll     return val;
867431213c9Ssrcarroll   auto valType = cast<ShapedType>(val.getType());
868431213c9Ssrcarroll   SmallVector<int64_t> collapsedShape(valType.getShape());
869431213c9Ssrcarroll   collapsedShape.erase(collapsedShape.begin() + pos);
870431213c9Ssrcarroll   return collapseValue(
871431213c9Ssrcarroll       rewriter, val.getLoc(), val, collapsedShape,
872431213c9Ssrcarroll       getReassociationForReshapeAtDim(valType.getRank(), pos),
873431213c9Ssrcarroll       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
874431213c9Ssrcarroll }
875431213c9Ssrcarroll 
876431213c9Ssrcarroll /// Base class for all rank reduction patterns for contraction ops
877431213c9Ssrcarroll /// with unit dimensions.  All patterns should convert one named op
878431213c9Ssrcarroll /// to another named op.  Intended to reduce only one iteration space dim
879431213c9Ssrcarroll /// at a time.
880431213c9Ssrcarroll /// Reducing multiple dims will happen with recusive application of
881431213c9Ssrcarroll /// pattern rewrites.
882431213c9Ssrcarroll template <typename FromOpTy, typename ToOpTy>
883431213c9Ssrcarroll struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
884431213c9Ssrcarroll   using OpRewritePattern<FromOpTy>::OpRewritePattern;
885431213c9Ssrcarroll 
886431213c9Ssrcarroll   /// Collapse all collapsable operands.
887431213c9Ssrcarroll   SmallVector<Value>
888431213c9Ssrcarroll   collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
889431213c9Ssrcarroll                    ArrayRef<int64_t> operandCollapseDims) const {
890431213c9Ssrcarroll     assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
891431213c9Ssrcarroll            "expected 3 operands and dims");
892431213c9Ssrcarroll     return llvm::map_to_vector(
893431213c9Ssrcarroll         llvm::zip(operands, operandCollapseDims), [&](auto pair) {
894431213c9Ssrcarroll           return collapseSingletonDimAt(rewriter, std::get<0>(pair),
895431213c9Ssrcarroll                                         std::get<1>(pair));
896431213c9Ssrcarroll         });
897431213c9Ssrcarroll   }
898431213c9Ssrcarroll 
899431213c9Ssrcarroll   /// Expand result tensor.
900431213c9Ssrcarroll   Value expandResult(PatternRewriter &rewriter, Value result,
901431213c9Ssrcarroll                      RankedTensorType expandedType, int64_t dim) const {
902431213c9Ssrcarroll     return rewriter.create<tensor::ExpandShapeOp>(
903431213c9Ssrcarroll         result.getLoc(), expandedType, result,
904431213c9Ssrcarroll         getReassociationForReshapeAtDim(expandedType.getRank(), dim));
905431213c9Ssrcarroll   }
906431213c9Ssrcarroll 
907431213c9Ssrcarroll   LogicalResult matchAndRewrite(FromOpTy contractionOp,
908431213c9Ssrcarroll                                 PatternRewriter &rewriter) const override {
909431213c9Ssrcarroll 
910431213c9Ssrcarroll     auto loc = contractionOp.getLoc();
911431213c9Ssrcarroll     auto inputs = contractionOp.getDpsInputs();
912431213c9Ssrcarroll     auto inits = contractionOp.getDpsInits();
913431213c9Ssrcarroll     if (inputs.size() != 2 || inits.size() != 1)
914431213c9Ssrcarroll       return rewriter.notifyMatchFailure(contractionOp,
915431213c9Ssrcarroll                                          "expected 2 inputs and 1 init");
916431213c9Ssrcarroll     auto lhs = inputs[0];
917431213c9Ssrcarroll     auto rhs = inputs[1];
918431213c9Ssrcarroll     auto init = inits[0];
919431213c9Ssrcarroll     SmallVector<Value> operands{lhs, rhs, init};
920431213c9Ssrcarroll 
921431213c9Ssrcarroll     SmallVector<int64_t> operandUnitDims;
922431213c9Ssrcarroll     if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
923431213c9Ssrcarroll       return rewriter.notifyMatchFailure(contractionOp,
924431213c9Ssrcarroll                                          "no reducable dims found");
925431213c9Ssrcarroll 
926431213c9Ssrcarroll     SmallVector<Value> collapsedOperands =
927431213c9Ssrcarroll         collapseOperands(rewriter, operands, operandUnitDims);
928431213c9Ssrcarroll     Value collapsedLhs = collapsedOperands[0];
929431213c9Ssrcarroll     Value collapsedRhs = collapsedOperands[1];
930431213c9Ssrcarroll     Value collapsedInit = collapsedOperands[2];
931431213c9Ssrcarroll     SmallVector<Type, 1> collapsedResultTy;
932431213c9Ssrcarroll     if (isa<RankedTensorType>(collapsedInit.getType()))
933431213c9Ssrcarroll       collapsedResultTy.push_back(collapsedInit.getType());
934431213c9Ssrcarroll     auto collapsedOp = rewriter.create<ToOpTy>(
935431213c9Ssrcarroll         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
936431213c9Ssrcarroll         ValueRange{collapsedInit});
937431213c9Ssrcarroll     for (auto attr : contractionOp->getAttrs()) {
938431213c9Ssrcarroll       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
939431213c9Ssrcarroll         continue;
940431213c9Ssrcarroll       collapsedOp->setAttr(attr.getName(), attr.getValue());
941431213c9Ssrcarroll     }
942431213c9Ssrcarroll 
943431213c9Ssrcarroll     auto results = contractionOp.getResults();
944431213c9Ssrcarroll     assert(results.size() < 2 && "expected at most one result");
945431213c9Ssrcarroll     if (results.empty()) {
946431213c9Ssrcarroll       rewriter.replaceOp(contractionOp, collapsedOp);
947431213c9Ssrcarroll     } else {
948431213c9Ssrcarroll       rewriter.replaceOp(
949431213c9Ssrcarroll           contractionOp,
950431213c9Ssrcarroll           expandResult(rewriter, collapsedOp.getResultTensors()[0],
951431213c9Ssrcarroll                        cast<RankedTensorType>(results[0].getType()),
952431213c9Ssrcarroll                        operandUnitDims[2]));
953431213c9Ssrcarroll     }
954431213c9Ssrcarroll 
955431213c9Ssrcarroll     return success();
956431213c9Ssrcarroll   }
957431213c9Ssrcarroll 
958431213c9Ssrcarroll   /// Populate `operandUnitDims` with 3 indices indicating the unit dim
959431213c9Ssrcarroll   /// for each operand that should be collapsed in this pattern.  If an
960431213c9Ssrcarroll   /// operand shouldn't be collapsed, the index should be negative.
961431213c9Ssrcarroll   virtual LogicalResult
962431213c9Ssrcarroll   getOperandUnitDims(LinalgOp op,
963431213c9Ssrcarroll                      SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
964431213c9Ssrcarroll };
965431213c9Ssrcarroll 
966431213c9Ssrcarroll /// Patterns for unbatching batched contraction ops
967431213c9Ssrcarroll template <typename FromOpTy, typename ToOpTy>
968431213c9Ssrcarroll struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
969431213c9Ssrcarroll   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
970431213c9Ssrcarroll 
971431213c9Ssrcarroll   /// Look for unit batch dims to collapse.
972431213c9Ssrcarroll   LogicalResult
973431213c9Ssrcarroll   getOperandUnitDims(LinalgOp op,
974431213c9Ssrcarroll                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
975431213c9Ssrcarroll     FailureOr<ContractionDimensions> maybeContractionDims =
976431213c9Ssrcarroll         inferContractionDims(op);
977431213c9Ssrcarroll     if (failed(maybeContractionDims)) {
978431213c9Ssrcarroll       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
979431213c9Ssrcarroll       return failure();
980431213c9Ssrcarroll     }
981431213c9Ssrcarroll     ContractionDimensions contractionDims = maybeContractionDims.value();
982431213c9Ssrcarroll 
983431213c9Ssrcarroll     if (contractionDims.batch.size() != 1)
984431213c9Ssrcarroll       return failure();
985431213c9Ssrcarroll     auto batchDim = contractionDims.batch[0];
986431213c9Ssrcarroll     SmallVector<std::pair<Value, unsigned>, 3> bOperands;
987431213c9Ssrcarroll     op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
988431213c9Ssrcarroll     if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
989431213c9Ssrcarroll           return cast<ShapedType>(std::get<0>(pair).getType())
990431213c9Ssrcarroll                      .getShape()[std::get<1>(pair)] != 1;
991431213c9Ssrcarroll         })) {
992431213c9Ssrcarroll       LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
993431213c9Ssrcarroll       return failure();
994431213c9Ssrcarroll     }
995431213c9Ssrcarroll 
996431213c9Ssrcarroll     operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
997431213c9Ssrcarroll                                            std::get<1>(bOperands[1]),
998431213c9Ssrcarroll                                            std::get<1>(bOperands[2])};
999431213c9Ssrcarroll     return success();
1000431213c9Ssrcarroll   }
1001431213c9Ssrcarroll };
1002431213c9Ssrcarroll 
1003431213c9Ssrcarroll /// Patterns for reducing non-batch dimensions
1004431213c9Ssrcarroll template <typename FromOpTy, typename ToOpTy>
1005431213c9Ssrcarroll struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1006431213c9Ssrcarroll   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1007431213c9Ssrcarroll 
1008431213c9Ssrcarroll   /// Helper for determining whether the lhs/init or rhs/init are reduced.
1009431213c9Ssrcarroll   static bool constexpr reduceLeft =
1010431213c9Ssrcarroll       (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1011431213c9Ssrcarroll        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1012431213c9Ssrcarroll       (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1013431213c9Ssrcarroll        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1014431213c9Ssrcarroll       (std::is_same_v<FromOpTy, MatmulOp> &&
1015431213c9Ssrcarroll        std::is_same_v<ToOpTy, VecmatOp>) ||
1016431213c9Ssrcarroll       (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1017431213c9Ssrcarroll        std::is_same_v<ToOpTy, VecmatOp>) ||
1018431213c9Ssrcarroll       (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1019431213c9Ssrcarroll 
1020431213c9Ssrcarroll   /// Look for non-batch spatial dims to collapse.
1021431213c9Ssrcarroll   LogicalResult
1022431213c9Ssrcarroll   getOperandUnitDims(LinalgOp op,
1023431213c9Ssrcarroll                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
1024431213c9Ssrcarroll     FailureOr<ContractionDimensions> maybeContractionDims =
1025431213c9Ssrcarroll         inferContractionDims(op);
1026431213c9Ssrcarroll     if (failed(maybeContractionDims)) {
1027431213c9Ssrcarroll       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1028431213c9Ssrcarroll       return failure();
1029431213c9Ssrcarroll     }
1030431213c9Ssrcarroll     ContractionDimensions contractionDims = maybeContractionDims.value();
1031431213c9Ssrcarroll 
1032431213c9Ssrcarroll     if constexpr (reduceLeft) {
1033431213c9Ssrcarroll       auto m = contractionDims.m[0];
1034431213c9Ssrcarroll       SmallVector<std::pair<Value, unsigned>, 2> mOperands;
1035431213c9Ssrcarroll       op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1036431213c9Ssrcarroll       if (mOperands.size() != 2)
1037431213c9Ssrcarroll         return failure();
1038431213c9Ssrcarroll       if (llvm::all_of(mOperands, [](auto pair) {
1039431213c9Ssrcarroll             return cast<ShapedType>(std::get<0>(pair).getType())
1040431213c9Ssrcarroll                        .getShape()[std::get<1>(pair)] == 1;
1041431213c9Ssrcarroll           })) {
1042431213c9Ssrcarroll         operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
1043431213c9Ssrcarroll                                                std::get<1>(mOperands[1])};
1044431213c9Ssrcarroll         return success();
1045431213c9Ssrcarroll       }
1046431213c9Ssrcarroll     } else {
1047431213c9Ssrcarroll       auto n = contractionDims.n[0];
1048431213c9Ssrcarroll       SmallVector<std::pair<Value, unsigned>, 2> nOperands;
1049431213c9Ssrcarroll       op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1050431213c9Ssrcarroll       if (nOperands.size() != 2)
1051431213c9Ssrcarroll         return failure();
1052431213c9Ssrcarroll       if (llvm::all_of(nOperands, [](auto pair) {
1053431213c9Ssrcarroll             return cast<ShapedType>(std::get<0>(pair).getType())
1054431213c9Ssrcarroll                        .getShape()[std::get<1>(pair)] == 1;
1055431213c9Ssrcarroll           })) {
1056431213c9Ssrcarroll         operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
1057431213c9Ssrcarroll                                                std::get<1>(nOperands[1])};
1058431213c9Ssrcarroll         return success();
1059431213c9Ssrcarroll       }
1060431213c9Ssrcarroll     }
1061431213c9Ssrcarroll     LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1062431213c9Ssrcarroll     return failure();
1063431213c9Ssrcarroll   }
1064431213c9Ssrcarroll };
1065431213c9Ssrcarroll 
1066431213c9Ssrcarroll } // namespace
1067431213c9Ssrcarroll 
1068431213c9Ssrcarroll void mlir::linalg::populateContractionOpRankReducingPatterns(
1069431213c9Ssrcarroll     RewritePatternSet &patterns) {
1070431213c9Ssrcarroll   MLIRContext *context = patterns.getContext();
1071431213c9Ssrcarroll   // Unbatching patterns for unit batch size
1072431213c9Ssrcarroll   patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1073431213c9Ssrcarroll   patterns
1074431213c9Ssrcarroll       .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1075431213c9Ssrcarroll           context);
1076431213c9Ssrcarroll   patterns
1077431213c9Ssrcarroll       .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1078431213c9Ssrcarroll           context);
1079431213c9Ssrcarroll   patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1080431213c9Ssrcarroll   patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1081431213c9Ssrcarroll 
1082431213c9Ssrcarroll   // Non-batch rank 1 reducing patterns
1083431213c9Ssrcarroll   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1084431213c9Ssrcarroll   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1085431213c9Ssrcarroll   patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1086431213c9Ssrcarroll   patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1087431213c9Ssrcarroll   // Batch rank 1 reducing patterns
1088431213c9Ssrcarroll   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1089431213c9Ssrcarroll   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1090431213c9Ssrcarroll   patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1091431213c9Ssrcarroll       context);
1092431213c9Ssrcarroll   patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1093431213c9Ssrcarroll       context);
1094431213c9Ssrcarroll 
1095431213c9Ssrcarroll   // Non-batch rank 0 reducing patterns
1096431213c9Ssrcarroll   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1097431213c9Ssrcarroll   patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
1098431213c9Ssrcarroll }
1099