1 //===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns/pass to remove usage of unit-extent dimensions
10 // to specify broadcasting in favor of a more canonical representation of
11 // the computation.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Linalg/Passes.h"
16 
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/Dialect/Linalg/Utils/Utils.h"
22 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
25 #include "mlir/Dialect/Tensor/Utils/Utils.h"
26 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
27 #include "mlir/IR/AffineExpr.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/BuiltinTypes.h"
30 #include "mlir/Transforms/FoldUtils.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 
36 namespace mlir {
37 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
39 } // namespace mlir
40 
41 #define DEBUG_TYPE "linalg-drop-unit-dims"
42 
43 using namespace mlir;
44 using namespace mlir::linalg;
45 
46 namespace {
47 /// Pattern to move init operands to ins when all the loops are parallel and
48 /// the block argument corresponding to an init is used in the region. This is
49 /// a fix-up for when unit reduction dimensions are all folded away. In that
50 /// case the op becomes an elementwise generic op. E.g., it converts
51 ///
52 ///  %0 = tensor.empty() : tensor<1x1xf32>
53 ///  %1 = linalg.fill
54 ///    ins(%cst : f32)
55 ///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
56 ///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
57 ///                                        affine_map<(d0) -> (0, d0)>],
58 ///                       iterator_types = ["parallel"]}
59 ///    ins(%arg0 : tensor<1x?x1x1xf32>)
60 ///    outs(%1 : tensor<1x1xf32>) {
61 ///  ^bb0(%in: f32, %out: f32):
62 ///    %3 = arith.addf %in, %out : f32
63 ///    linalg.yield %3 : f32
64 ///  } -> tensor<1x1xf32>
65 ///
66 ///  into
67 ///
68 ///  %0 = tensor.empty() : tensor<1x1xf32>
69 ///  %1 = linalg.fill
70 ///    ins(%cst : f32)
71 ///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
72 ///  %2 = tensor.empty() : tensor<1x1xf32>
73 ///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
74 ///                                        affine_map<(d0) -> (0, d0)>,
75 ///                                        affine_map<(d0) -> (0, d0)>],
76 ///                       iterator_types = ["parallel"]}
77 ///   ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
78 ///   outs(%2 : tensor<1x1xf32>) {
79 ///  ^bb0(%in: f32, %in_0: f32, %out: f32):
80 ///    %4 = arith.addf %in, %in_0 : f32
81 ///    linalg.yield %4 : f32
82 ///  } -> tensor<1x1xf32>
83 struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
84   using OpRewritePattern<GenericOp>::OpRewritePattern;
85   LogicalResult matchAndRewrite(GenericOp genericOp,
86                                 PatternRewriter &rewriter) const override {
87     if (!genericOp.hasPureTensorSemantics())
88       return failure();
89     if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
90       return failure();
91 
92     auto outputOperands = genericOp.getDpsInitsMutable();
93     SetVector<OpOperand *> candidates;
94     for (OpOperand &op : outputOperands) {
95       if (genericOp.getMatchingBlockArgument(&op).use_empty())
96         continue;
97       candidates.insert(&op);
98     }
99 
100     if (candidates.empty())
101       return failure();
102 
103     // Compute the modified indexing maps.
104     int64_t origNumInput = genericOp.getNumDpsInputs();
105     SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
106     SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
107     SmallVector<AffineMap> newIndexingMaps;
108     newIndexingMaps.append(indexingMaps.begin(),
109                            std::next(indexingMaps.begin(), origNumInput));
110     for (OpOperand *op : candidates) {
111       newInputOperands.push_back(op->get());
112       newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
113     }
114     newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
115                            indexingMaps.end());
116 
117     Location loc = genericOp.getLoc();
118     SmallVector<Value> newOutputOperands =
119         llvm::to_vector(genericOp.getDpsInits());
120     for (OpOperand *op : candidates) {
121       OpBuilder::InsertionGuard guard(rewriter);
122       rewriter.setInsertionPointAfterValue(op->get());
123       auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
124       auto empty = rewriter.create<tensor::EmptyOp>(
125           loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
126 
127       unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
128       newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
129     }
130 
131     auto newOp = rewriter.create<GenericOp>(
132         loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133         newIndexingMaps, genericOp.getIteratorTypesArray(),
134         /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
135 
136     OpBuilder::InsertionGuard guard(rewriter);
137     Region &region = newOp.getRegion();
138     Block *block = rewriter.createBlock(&region);
139     IRMapping mapper;
140     for (auto bbarg : genericOp.getRegionInputArgs())
141       mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
142 
143     for (OpOperand *op : candidates) {
144       BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
145       mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
146     }
147 
148     for (OpOperand &op : outputOperands) {
149       BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
150       if (candidates.count(&op))
151         block->addArgument(bbarg.getType(), loc);
152       else
153         mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
154     }
155 
156     for (auto &op : genericOp.getBody()->getOperations()) {
157       rewriter.clone(op, mapper);
158     }
159     rewriter.replaceOp(genericOp, newOp.getResults());
160 
161     return success();
162   }
163 };
164 } // namespace
165 
166 //===---------------------------------------------------------------------===//
167 // Drop loops that are unit-extents within Linalg operations.
168 //===---------------------------------------------------------------------===//
169 
170 /// Implements a pass that canonicalizes the uses of unit-extent dimensions for
171 /// broadcasting. For example,
172 ///
173 /// ```mlir
174 /// #accesses = [
175 ///   affine_map<(d0, d1) -> (0, d1)>,
176 ///   affine_map<(d0, d1) -> (d0, 0)>,
177 ///   affine_map<(d0, d1) -> (d0, d1)>
178 /// ]
179 ///
180 /// #trait = {
181 ///   indexing_maps = #accesses,
182 ///   iterator_types = ["parallel", "parallel"],
183 ///   library_call = "some_external_fn"
184 /// }
185 ///
186 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
187 /// tensor<5x5xf32>
188 /// {
189 ///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
190 ///        tensor<5xf32> into tensor<1x5xf32>
191 ///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
192 ///        tensor<5xf32> into tensor<5x1xf32>
193 ///   %2 = linalg.generic #trait %0, %1 {
194 ///        ^bb0(%arg2: f32, %arg3: f32):
195 ///          %3 = arith.addf %arg2, %arg3 : f32
196 ///          linalg.yield %3 : f32
197 ///        } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
198 ///   return %2 : tensor<5x5xf32>
199 /// }
200 /// ```
201 /// would canonicalize to
202 ///
203 /// ```mlir
204 /// #accesses = [
205 ///   affine_map<(d0, d1) -> (d1)>,
206 ///   affine_map<(d0, d1) -> (d0)>,
207 ///   affine_map<(d0, d1) -> (d0, d1)>
208 /// ]
209 ///
210 /// #trait = {
211 ///   indexing_maps = #accesses,
212 ///   iterator_types = ["parallel", "parallel"],
213 ///   library_call = "some_external_fn"
214 /// }
215 ///
216 /// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
217 /// tensor<5x5xf32>
218 /// {
219 ///   %0 = linalg.generic #trait %arg0, %arg1 {
220 ///        ^bb0(%arg2: f32, %arg3: f32):
221 ///          %3 = arith.addf %arg2, %arg3 : f32
222 ///          linalg.yield %3 : f32
223 ///        } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
224 ///   return %0 : tensor<5x5xf32>
225 /// }
/// ```
226 
227 /// Update the index accesses of linalg operations having index semantics.
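/// For example (illustrative), if loop dimension 1 is one of the dropped unit
/// dimensions, then `linalg.index 1` in the body is replaced by a constant 0,
/// and `linalg.index 2` is rebased to `linalg.index 1` to account for the
/// dropped loop.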
228 static void
229 replaceUnitDimIndexOps(GenericOp genericOp,
230                        const llvm::SmallDenseSet<unsigned> &unitDims,
231                        RewriterBase &rewriter) {
232   for (IndexOp indexOp :
233        llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
234     OpBuilder::InsertionGuard guard(rewriter);
235     rewriter.setInsertionPoint(indexOp);
236     if (unitDims.count(indexOp.getDim()) != 0) {
237       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
238     } else {
239       // Update the dimension of the index operation if needed.
240       unsigned droppedDims = llvm::count_if(
241           unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
242       if (droppedDims != 0)
243         rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
244                                              indexOp.getDim() - droppedDims);
245     }
246   }
247 }
248 
249 /// Expand the given `result` so that its type matches the type of `origDest`.
250 /// The `reassociation` is used when `rankReductionStrategy` is set to
251 /// `RankReductionStrategy::ReassociativeReshape`.
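/// For example (illustrative), a collapsed result of type tensor<4xf32> with
/// `origDest` of type tensor<1x4x1xf32> is expanded either with a
/// `tensor.insert_slice` into `origDest` (ExtractInsertSlice strategy) or with
/// a `tensor.expand_shape` using `reassociation` (ReassociativeReshape
/// strategy).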
252 static Value
253 expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
254             ArrayRef<ReassociationIndices> reassociation,
255             ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
256   // There are no results for memref outputs.
257   auto origResultType = cast<RankedTensorType>(origDest.getType());
258   if (rankReductionStrategy ==
259       ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
260     unsigned rank = origResultType.getRank();
261     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
262     SmallVector<OpFoldResult> sizes =
263         tensor::getMixedSizes(rewriter, loc, origDest);
264     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
265     return rewriter.createOrFold<tensor::InsertSliceOp>(
266         loc, result, origDest, offsets, sizes, strides);
267   }
268 
269   assert(rankReductionStrategy ==
270              ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
271          "unknown rank reduction strategy");
272   return rewriter
273       .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
274       .getResult();
275 }
276 
277 /// Collapse the given `operand` so that its type matches the given
278 /// `targetShape`. The `reassociation` is used when `rankReductionStrategy` is
279 /// set to `RankReductionStrategy::ReassociativeReshape`.
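/// For example (illustrative), an operand of type tensor<1x4x1xf32> with
/// `targetShape` = [4] is collapsed either with a rank-reducing
/// `tensor.extract_slice` (ExtractInsertSlice strategy) or with a
/// `tensor.collapse_shape` using `reassociation` (ReassociativeReshape
/// strategy); memref operands use the matching `memref.subview` or
/// `memref.collapse_shape`.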
280 static Value collapseValue(
281     RewriterBase &rewriter, Location loc, Value operand,
282     ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
283     ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
284   if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
285     if (rankReductionStrategy ==
286         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
287       FailureOr<Value> rankReducingExtract =
288           memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
289                                                 targetShape);
290       assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
291       return *rankReducingExtract;
292     }
293 
294     assert(
295         rankReductionStrategy ==
296             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
297         "unknown rank reduction strategy");
298     MemRefLayoutAttrInterface layout;
299     auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
300                                       layout, memrefType.getMemorySpace());
301     return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
302                                                     reassociation);
303   }
304   if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
305     if (rankReductionStrategy ==
306         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
307       FailureOr<Value> rankReducingExtract =
308           tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
309                                                      targetShape);
310       assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
311       return *rankReducingExtract;
312     }
313 
314     assert(
315         rankReductionStrategy ==
316             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
317         "unknown rank reduction strategy");
318     auto targetType =
319         RankedTensorType::get(targetShape, tensorType.getElementType());
320     return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
321                                                     reassociation);
322   }
323   llvm_unreachable("unsupported operand type");
324 }
325 
326 /// Compute the modified metadata for an operand of an operation
327 /// whose unit dims are being dropped. Return the new indexing map
328 /// to use, the shape of the operand in the replacement op,
329 /// and the `reassociation` to use to go from the original operand shape
330 /// to the modified operand shape.
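/// For example (illustrative), assuming loop dimensions d0 and d2 are dropped
/// unit dimensions, an operand of type tensor<1x5x1xf32> with indexing map
/// `(d0, d1, d2) -> (d0, d1, d2)` gets target shape [5], indexing map
/// `(d0) -> (d0)` and reassociation [[0, 1, 2]].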
331 struct UnitExtentReplacementInfo {
332   AffineMap indexMap;
333   SmallVector<ReassociationIndices> reassociation;
334   SmallVector<int64_t> targetShape;
335 };
336 static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
337     MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
338     llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
339     ArrayRef<AffineExpr> dimReplacements) {
340   UnitExtentReplacementInfo info;
341   ReassociationIndices reassociationGroup;
342   SmallVector<AffineExpr> newIndexExprs;
343   AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
344   ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
345   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
346 
347   auto isUnitDim = [&](unsigned dim) {
348     if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
349       unsigned oldPosition = dimExpr.getPosition();
350       return !oldDimsToNewDimsMap.count(oldPosition) &&
351              (operandShape[dim] == 1);
352     }
353     // Handle the other case where the shape is 1, and is accessed using a
354     // constant 0.
355     if (operandShape[dim] == 1) {
356       auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
357       return constAffineExpr && constAffineExpr.getValue() == 0;
358     }
359     return false;
360   };
361 
362   unsigned dim = 0;
363   while (dim < operandShape.size() && isUnitDim(dim))
364     reassociationGroup.push_back(dim++);
365   while (dim < operandShape.size()) {
366     assert(!isUnitDim(dim) && "expected non unit-extent");
367     reassociationGroup.push_back(dim);
368     AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
369     newIndexExprs.push_back(newExpr);
370     info.targetShape.push_back(operandShape[dim]);
371     ++dim;
372     // Fold all following dimensions that are unit-extent.
373     while (dim < operandShape.size() && isUnitDim(dim)) {
374       reassociationGroup.push_back(dim++);
375     }
376     info.reassociation.push_back(reassociationGroup);
377     reassociationGroup.clear();
378   }
379   info.indexMap =
380       AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
381                      newIndexExprs, context);
382   return info;
383 }
384 
385 FailureOr<DropUnitDimsResult>
386 linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
387                      const ControlDropUnitDims &options) {
388   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
389   if (indexingMaps.empty())
390     return failure();
391 
392   // 1. Check if any of the iteration dimensions have a unit trip count. They
393   //    will end up with a unit trip count if they are used to index into a
394   //    unit-dim tensor/memref.
395   AffineMap invertedMap =
396       inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
397   if (!invertedMap) {
398     return rewriter.notifyMatchFailure(genericOp,
399                                        "invalid indexing maps for operation");
400   }
401   SmallVector<int64_t> dims = genericOp.getStaticShape();
402 
403   // 1a. Get the allowed list of dimensions to drop from the `options`.
404   SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
405   if (allowedUnitDims.empty()) {
406     return rewriter.notifyMatchFailure(
407         genericOp, "control function returns no allowed unit dims to prune");
408   }
409   llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
410                                                allowedUnitDims.end());
411   llvm::SmallDenseSet<unsigned> unitDims;
412   for (const auto &expr : enumerate(invertedMap.getResults())) {
413     if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
414       if (dims[dimExpr.getPosition()] == 1 &&
415           unitDimsFilter.count(expr.index()))
416         unitDims.insert(expr.index());
417     }
418   }
419 
420   // 2. Compute the iterator types of the modified op by dropping the one-trip
421   //    count loops.
422   SmallVector<utils::IteratorType> newIteratorTypes;
423   llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
424   SmallVector<AffineExpr> dimReplacements;
425   unsigned newDims = 0;
426   for (auto [index, attr] :
427        llvm::enumerate(genericOp.getIteratorTypesArray())) {
428     if (unitDims.count(index)) {
429       dimReplacements.push_back(
430           getAffineConstantExpr(0, rewriter.getContext()));
431     } else {
432       newIteratorTypes.push_back(attr);
433       oldDimToNewDimMap[index] = newDims;
434       dimReplacements.push_back(
435           getAffineDimExpr(newDims, rewriter.getContext()));
436       newDims++;
437     }
438   }
439 
440   // 3. For each of the operands, find the
441   //    - modified affine map to use.
442   //    - shape of the operands after the unit-dims are dropped.
443   //    - the reassociation indices used to convert from the original
444   //      operand type to modified operand (needed only when using reshapes
445   //      for rank reduction strategy)
446   // Note that the indexing maps might need changing even if there are no
447   // unit dimensions that are dropped, to handle cases where `0` is used to
448   // access a unit-extent tensor. Consider moving this out of this specific
449   // transformation as a stand-alone transformation. It is kept here for now
450   // for legacy reasons.
451   SmallVector<AffineMap> newIndexingMaps;
452   SmallVector<SmallVector<ReassociationIndices>> reassociations;
453   SmallVector<SmallVector<int64_t>> targetShapes;
454   SmallVector<bool> collapsed;
455   auto hasCollapsibleType = [](OpOperand &operand) {
456     Type operandType = operand.get().getType();
457     if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
458       return memrefOperandType.getLayout().isIdentity();
459     }
460     if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
461       return tensorOperandType.getEncoding() == nullptr;
462     }
463     return false;
464   };
465   for (OpOperand &opOperand : genericOp->getOpOperands()) {
466     auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467     ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
468     if (!hasCollapsibleType(opOperand)) {
469       AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
470           dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
471       newIndexingMaps.push_back(newIndexingMap);
472       targetShapes.push_back(llvm::to_vector(shape));
473       collapsed.push_back(false);
474       reassociations.push_back({});
475       continue;
476     }
477     auto replacementInfo = dropUnitExtentFromOperandMetadata(
478         rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
479         dimReplacements);
480     reassociations.push_back(replacementInfo.reassociation);
481     newIndexingMaps.push_back(replacementInfo.indexMap);
482     targetShapes.push_back(replacementInfo.targetShape);
483     collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
484                           indexingMap.getNumResults()));
485   }
486 
487   // Abort if the indexing maps of the result operation are not invertible
488   // (i.e. not legal) or if no dimension was reduced.
489   if (newIndexingMaps == indexingMaps ||
490       !inversePermutation(
491           concatAffineMaps(newIndexingMaps, rewriter.getContext())))
492     return failure();
493 
494   Location loc = genericOp.getLoc();
495   // 4. For each of the operands, collapse the operand to convert
496   //    from original shape to shape in the modified operation if needed,
497   //    either through use of reshapes or rank-reducing slices as
498   //    specified in `options`.
499   SmallVector<Value> newOperands;
500   for (OpOperand &opOperand : genericOp->getOpOperands()) {
501     int64_t idx = opOperand.getOperandNumber();
502     if (!collapsed[idx]) {
503       newOperands.push_back(opOperand.get());
504       continue;
505     }
506     newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
507                                         targetShapes[idx], reassociations[idx],
508                                         options.rankReductionStrategy));
509   }
510 
511   // 5. Create the `linalg.generic` operation with the new operands,
512   //    indexing maps, iterator types and result types.
513   ArrayRef<Value> newInputs =
514       ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
515   ArrayRef<Value> newOutputs =
516       ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
517   SmallVector<Type> resultTypes;
518   resultTypes.reserve(genericOp.getNumResults());
519   for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
520     resultTypes.push_back(newOutputs[i].getType());
521   GenericOp replacementOp =
522       rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
523                                  newIndexingMaps, newIteratorTypes);
524   rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
525                               replacementOp.getRegion().begin());
526   // 5a. Replace `linalg.index` operations that refer to the dropped unit
527   //     dimensions.
528   replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
529 
530   // 6. If any result type changes, insert a reshape/slice to convert from the
531   //    original type to the new type.
532   SmallVector<Value> resultReplacements;
533   for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
534     unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
535     Value origDest = genericOp.getDpsInitOperand(index)->get();
536     if (!collapsed[opOperandIndex]) {
537       resultReplacements.push_back(result);
538       continue;
539     }
540     Value expandedValue = expandValue(rewriter, loc, result, origDest,
541                                       reassociations[opOperandIndex],
542                                       options.rankReductionStrategy);
543     resultReplacements.push_back(expandedValue);
544   }
545 
546   return DropUnitDimsResult{replacementOp, resultReplacements};
547 }
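
// A minimal usage sketch for the above entry point (illustrative only; it
// mirrors the `DropUnitDims` pattern below, and the rewriter/op variables are
// assumed to exist in the caller):
//
//   linalg::ControlDropUnitDims options;
//   options.rankReductionStrategy =
//       linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
//   // Restrict the transformation to dropping loop dimension 0 only.
//   options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
//   FailureOr<linalg::DropUnitDimsResult> result =
//       linalg::dropUnitDims(rewriter, genericOp, options);
//   if (succeeded(result))
//     rewriter.replaceOp(genericOp, result->replacements);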
548 
549 namespace {
550 struct DropUnitDims : public OpRewritePattern<GenericOp> {
551   DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
552                PatternBenefit benefit = 1)
553       : OpRewritePattern(context, benefit), options(std::move(options)) {}
554 
555   LogicalResult matchAndRewrite(GenericOp genericOp,
556                                 PatternRewriter &rewriter) const override {
557     FailureOr<DropUnitDimsResult> result =
558         dropUnitDims(rewriter, genericOp, options);
559     if (failed(result)) {
560       return failure();
561     }
562     rewriter.replaceOp(genericOp, result->replacements);
563     return success();
564   }
565 
566 private:
567   ControlDropUnitDims options;
568 };
569 } // namespace
570 
571 //===---------------------------------------------------------------------===//
572 // Drop dimensions that are unit-extents within tensor operations.
573 //===---------------------------------------------------------------------===//
574 
575 namespace {
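/// Drops unit dimensions from the source/result of `tensor.pad` ops when the
/// padding along those dimensions is statically zero. For example
/// (illustrative), a pad of tensor<1x?x8xf32> with low = [0, 1, 0],
/// high = [0, 2, 0] and a constant padding value is rewritten to a pad of
/// tensor<?x8xf32>, whose result is then expanded (or inserted via a slice)
/// back to the original rank.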
576 struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
577   DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
578                   PatternBenefit benefit = 1)
579       : OpRewritePattern(context, benefit), options(std::move(options)) {}
580 
581   LogicalResult matchAndRewrite(tensor::PadOp padOp,
582                                 PatternRewriter &rewriter) const override {
583     // 1a. Get the allowed list of dimensions to drop from the `options`.
584     SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
585     if (allowedUnitDims.empty()) {
586       return rewriter.notifyMatchFailure(
587           padOp, "control function returns no allowed unit dims to prune");
588     }
589 
590     if (padOp.getSourceType().getEncoding()) {
591       return rewriter.notifyMatchFailure(
592           padOp, "cannot collapse dims of tensor with encoding");
593     }
594 
595     // Fail for non-constant padding values. The body of the pad could
596     // depend on the padding indices and/or properties of the padded
597     // tensor so for now we fail.
598     // TODO: Support non-constant padding values.
599     Value paddingVal = padOp.getConstantPaddingValue();
600     if (!paddingVal) {
601       return rewriter.notifyMatchFailure(
602           padOp, "unimplemented: non-constant padding value");
603     }
604 
605     ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
606     int64_t padRank = sourceShape.size();
607 
608     auto isStaticZero = [](OpFoldResult f) {
609       std::optional<int64_t> maybeInt = getConstantIntValue(f);
610       return maybeInt && *maybeInt == 0;
611     };
612 
613     llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
614                                                  allowedUnitDims.end());
615     llvm::SmallDenseSet<unsigned> unitDims;
616     SmallVector<int64_t> newShape;
617     SmallVector<OpFoldResult> newLowPad;
618     SmallVector<OpFoldResult> newHighPad;
619     for (const auto [dim, size, low, high] :
620          zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
621                    padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
622       if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
623           isStaticZero(high)) {
624         unitDims.insert(dim);
625       } else {
626         newShape.push_back(size);
627         newLowPad.push_back(low);
628         newHighPad.push_back(high);
629       }
630     }
631 
632     if (unitDims.empty()) {
633       return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
634     }
635 
636     ReassociationIndices reassociationGroup;
637     SmallVector<ReassociationIndices> reassociationMap;
638     int64_t dim = 0;
639     while (dim < padRank && unitDims.contains(dim))
640       reassociationGroup.push_back(dim++);
641     while (dim < padRank) {
642       assert(!unitDims.contains(dim) && "expected non unit-extent");
643       reassociationGroup.push_back(dim);
644       dim++;
645       // Fold all following dimensions that are unit-extent.
646       while (dim < padRank && unitDims.contains(dim))
647         reassociationGroup.push_back(dim++);
648       reassociationMap.push_back(reassociationGroup);
649       reassociationGroup.clear();
650     }
651 
652     Value collapsedSource =
653         collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
654                       reassociationMap, options.rankReductionStrategy);
655 
656     auto newPadOp = rewriter.create<tensor::PadOp>(
657         padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
658         newHighPad, paddingVal, padOp.getNofold());
659 
660     Value dest = padOp.getResult();
661     if (options.rankReductionStrategy ==
662         ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
663       SmallVector<OpFoldResult> expandedSizes;
664       int64_t numUnitDims = 0;
665       for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
666         if (unitDims.contains(dim)) {
667           expandedSizes.push_back(rewriter.getIndexAttr(1));
668           numUnitDims++;
669           continue;
670         }
671         expandedSizes.push_back(tensor::getMixedSize(
672             rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
673       }
674       dest = rewriter.create<tensor::EmptyOp>(
675           padOp.getLoc(), expandedSizes,
676           padOp.getResultType().getElementType());
677     }
678 
679     Value expandedValue =
680         expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
681                     reassociationMap, options.rankReductionStrategy);
682     rewriter.replaceOp(padOp, expandedValue);
683     return success();
684   }
685 
686 private:
687   ControlDropUnitDims options;
688 };
689 } // namespace
690 
691 namespace {
692 /// Convert `extract_slice` operations to rank-reduced versions.
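/// For example (illustrative):
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 16, 4] [1, 1, 1]
///       : tensor<8x16x4xf32> to tensor<1x16x4xf32>
///
/// is rewritten to a rank-reduced `tensor.extract_slice` producing
/// tensor<16x4xf32>, followed by a `tensor.expand_shape` back to
/// tensor<1x16x4xf32>.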
693 struct RankReducedExtractSliceOp
694     : public OpRewritePattern<tensor::ExtractSliceOp> {
695   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
696 
697   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
698                                 PatternRewriter &rewriter) const override {
699     RankedTensorType resultType = sliceOp.getType();
700     SmallVector<OpFoldResult> targetShape;
701     for (auto size : resultType.getShape())
702       targetShape.push_back(rewriter.getIndexAttr(size));
703     auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
704     if (!reassociation ||
705         reassociation->size() == static_cast<size_t>(resultType.getRank()))
706       return failure();
707 
708     SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
709     SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
710     SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
711     auto rankReducedType = cast<RankedTensorType>(
712         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
713             reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
714             strides));
715 
716     Location loc = sliceOp.getLoc();
717     Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
718         loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides);
719     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
720         sliceOp, resultType, newSlice, *reassociation);
721     return success();
722   }
723 };
724 
725 /// Convert `insert_slice` operations to rank-reduced versions.
726 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
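/// For example (illustrative), inserting a tensor<1x16x4xf32> source into a
/// larger destination is rewritten so that the source is first collapsed with
/// `tensor.collapse_shape` to tensor<16x4xf32> and then inserted with the
/// original (rank-reducing) offsets, sizes and strides.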
727 template <typename InsertOpTy>
728 struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
729   using OpRewritePattern<InsertOpTy>::OpRewritePattern;
730 
731   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
732                                 PatternRewriter &rewriter) const override {
733     RankedTensorType sourceType = insertSliceOp.getSourceType();
734     SmallVector<OpFoldResult> targetShape;
735     for (auto size : sourceType.getShape())
736       targetShape.push_back(rewriter.getIndexAttr(size));
737     auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
738     if (!reassociation ||
739         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
740       return failure();
741 
742     Location loc = insertSliceOp.getLoc();
743     tensor::CollapseShapeOp reshapedSource;
744     {
745       OpBuilder::InsertionGuard g(rewriter);
746       // The only difference between InsertSliceOp and ParallelInsertSliceOp
747       // is that the insertion point is just before the ParallelCombiningOp in
748       // the parallel case.
749       if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
750         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
751       reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
752           loc, insertSliceOp.getSource(), *reassociation);
753     }
754     rewriter.replaceOpWithNewOp<InsertOpTy>(
755         insertSliceOp, reshapedSource, insertSliceOp.getDest(),
756         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
757         insertSliceOp.getMixedStrides());
758     return success();
759   }
760 };
761 } // namespace
762 
763 /// Patterns that are used to canonicalize the use of unit-extent dims for
764 /// broadcasting.
765 static void
766 populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
767                                               ControlDropUnitDims &options) {
768   auto *context = patterns.getContext();
769   patterns.add<DropUnitDims>(context, options);
770   patterns.add<DropPadUnitDims>(context, options);
771   // TODO: Patterns unrelated to unit dim folding should be factored out.
772   patterns.add<RankReducedExtractSliceOp,
773                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
774                RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
775       context);
776   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
777   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
778   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
779   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
780   tensor::populateFoldTensorEmptyPatterns(patterns);
781   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
782   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
783 }
784 
785 static void
786 populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
787                                             ControlDropUnitDims &options) {
788   auto *context = patterns.getContext();
789   patterns.add<DropUnitDims>(context, options);
790   patterns.add<DropPadUnitDims>(context, options);
791   // TODO: Patterns unrelated to unit dim folding should be factored out.
792   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
793   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
794   tensor::populateFoldTensorEmptyPatterns(patterns);
795   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
796   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
797 }
798 
799 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
800     RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
801   if (options.rankReductionStrategy ==
802       linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
803     populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
804   } else if (options.rankReductionStrategy ==
805              linalg::ControlDropUnitDims::RankReductionStrategy::
806                  ReassociativeReshape) {
807     populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
808   }
809 }
810 
811 void mlir::linalg::populateMoveInitOperandsToInputPattern(
812     RewritePatternSet &patterns) {
813   patterns.add<MoveInitOperandsToInput>(patterns.getContext());
814 }
815 
816 namespace {
817 /// Pass that removes unit-extent dims within generic ops.
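/// This is typically exercised as `mlir-opt -linalg-fold-unit-extent-dims`; the
/// `use-rank-reducing-slices` option (assumed CLI spelling of the
/// `useRankReducingSlices` flag used below) selects slice-based instead of
/// reshape-based rank reduction.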
818 struct LinalgFoldUnitExtentDimsPass
819     : public impl::LinalgFoldUnitExtentDimsPassBase<
820           LinalgFoldUnitExtentDimsPass> {
821   using impl::LinalgFoldUnitExtentDimsPassBase<
822       LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
823   void runOnOperation() override {
824     Operation *op = getOperation();
825     MLIRContext *context = op->getContext();
826     RewritePatternSet patterns(context);
827     ControlDropUnitDims options;
828     if (useRankReducingSlices) {
829       options.rankReductionStrategy = linalg::ControlDropUnitDims::
830           RankReductionStrategy::ExtractInsertSlice;
831     }
832     linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
833     populateMoveInitOperandsToInputPattern(patterns);
834     (void)applyPatternsGreedily(op, std::move(patterns));
835   }
836 };
837 
838 } // namespace
839 
840 namespace {
841 
842 /// Returns reassociation indices for collapsing/expanding a unit dimension
843 /// at position `pos` of a tensor of rank `rank`.
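/// For example (illustrative), with `rank` = 4: `pos` = 1 yields
/// [[0], [1, 2], [3]] and `pos` = 3 yields [[0], [1], [2, 3]]; with
/// `rank` = 2 any `pos` yields [[0, 1]].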
844 static SmallVector<ReassociationIndices>
845 getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
846   SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
847   bool lastDim = pos == rank - 1;
848   if (rank > 2) {
849     for (int64_t i = 0; i < rank - 1; i++) {
850       if (i == pos || (lastDim && i == pos - 1))
851         reassociation[i] = ReassociationIndices{i, i + 1};
852       else if (i < pos)
853         reassociation[i] = ReassociationIndices{i};
854       else
855         reassociation[i] = ReassociationIndices{i + 1};
856     }
857   }
858   return reassociation;
859 }
860 
861 /// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
862 /// If `pos < 0`, then don't collapse.
863 static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
864                                     int64_t pos) {
865   if (pos < 0)
866     return val;
867   auto valType = cast<ShapedType>(val.getType());
868   SmallVector<int64_t> collapsedShape(valType.getShape());
869   collapsedShape.erase(collapsedShape.begin() + pos);
870   return collapseValue(
871       rewriter, val.getLoc(), val, collapsedShape,
872       getReassociationForReshapeAtDim(valType.getRank(), pos),
873       ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
874 }
875 
876 /// Base class for all rank reduction patterns for contraction ops
877 /// with unit dimensions.  All patterns should convert one named op
878 /// to another named op.  Intended to reduce only one iteration space dim
879 /// at a time.
880 /// Reducing multiple dims will happen with recursive application of
881 /// pattern rewrites.
882 template <typename FromOpTy, typename ToOpTy>
883 struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
884   using OpRewritePattern<FromOpTy>::OpRewritePattern;
885 
886   /// Collapse all collapsible operands.
887   SmallVector<Value>
888   collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
889                    ArrayRef<int64_t> operandCollapseDims) const {
890     assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
891            "expected 3 operands and dims");
892     return llvm::map_to_vector(
893         llvm::zip(operands, operandCollapseDims), [&](auto pair) {
894           return collapseSingletonDimAt(rewriter, std::get<0>(pair),
895                                         std::get<1>(pair));
896         });
897   }
898 
899   /// Expand result tensor.
900   Value expandResult(PatternRewriter &rewriter, Value result,
901                      RankedTensorType expandedType, int64_t dim) const {
902     return rewriter.create<tensor::ExpandShapeOp>(
903         result.getLoc(), expandedType, result,
904         getReassociationForReshapeAtDim(expandedType.getRank(), dim));
905   }
906 
907   LogicalResult matchAndRewrite(FromOpTy contractionOp,
908                                 PatternRewriter &rewriter) const override {
909 
910     auto loc = contractionOp.getLoc();
911     auto inputs = contractionOp.getDpsInputs();
912     auto inits = contractionOp.getDpsInits();
913     if (inputs.size() != 2 || inits.size() != 1)
914       return rewriter.notifyMatchFailure(contractionOp,
915                                          "expected 2 inputs and 1 init");
916     auto lhs = inputs[0];
917     auto rhs = inputs[1];
918     auto init = inits[0];
919     SmallVector<Value> operands{lhs, rhs, init};
920 
921     SmallVector<int64_t> operandUnitDims;
922     if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
923       return rewriter.notifyMatchFailure(contractionOp,
924                                          "no reducible dims found");
925 
926     SmallVector<Value> collapsedOperands =
927         collapseOperands(rewriter, operands, operandUnitDims);
928     Value collapsedLhs = collapsedOperands[0];
929     Value collapsedRhs = collapsedOperands[1];
930     Value collapsedInit = collapsedOperands[2];
931     SmallVector<Type, 1> collapsedResultTy;
932     if (isa<RankedTensorType>(collapsedInit.getType()))
933       collapsedResultTy.push_back(collapsedInit.getType());
934     auto collapsedOp = rewriter.create<ToOpTy>(
935         loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
936         ValueRange{collapsedInit});
937     for (auto attr : contractionOp->getAttrs()) {
938       if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
939         continue;
940       collapsedOp->setAttr(attr.getName(), attr.getValue());
941     }
942 
943     auto results = contractionOp.getResults();
944     assert(results.size() < 2 && "expected at most one result");
945     if (results.empty()) {
946       rewriter.replaceOp(contractionOp, collapsedOp);
947     } else {
948       rewriter.replaceOp(
949           contractionOp,
950           expandResult(rewriter, collapsedOp.getResultTensors()[0],
951                        cast<RankedTensorType>(results[0].getType()),
952                        operandUnitDims[2]));
953     }
954 
955     return success();
956   }
957 
958   /// Populate `operandUnitDims` with 3 indices indicating the unit dim
959   /// for each operand that should be collapsed in this pattern.  If an
960   /// operand shouldn't be collapsed, the index should be negative.
961   virtual LogicalResult
962   getOperandUnitDims(LinalgOp op,
963                      SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
964 };
965 
966 /// Patterns for unbatching batched contraction ops
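/// For example (illustrative), a `linalg.batch_matmul` whose batch extent is 1
/// (operands tensor<1x4x8xf32> and tensor<1x8x16xf32>, init tensor<1x4x16xf32>)
/// is rewritten to a `linalg.matmul` on the collapsed tensor<4x8xf32>,
/// tensor<8x16xf32> and tensor<4x16xf32>, with the result expanded back to
/// tensor<1x4x16xf32>.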
967 template <typename FromOpTy, typename ToOpTy>
968 struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
969   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
970 
971   /// Look for unit batch dims to collapse.
972   LogicalResult
973   getOperandUnitDims(LinalgOp op,
974                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
975     FailureOr<ContractionDimensions> maybeContractionDims =
976         inferContractionDims(op);
977     if (failed(maybeContractionDims)) {
978       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
979       return failure();
980     }
981     ContractionDimensions contractionDims = maybeContractionDims.value();
982 
983     if (contractionDims.batch.size() != 1)
984       return failure();
985     auto batchDim = contractionDims.batch[0];
986     SmallVector<std::pair<Value, unsigned>, 3> bOperands;
987     op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
988     if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
989           return cast<ShapedType>(std::get<0>(pair).getType())
990                      .getShape()[std::get<1>(pair)] != 1;
991         })) {
992       LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
993       return failure();
994     }
995 
996     operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
997                                            std::get<1>(bOperands[1]),
998                                            std::get<1>(bOperands[2])};
999     return success();
1000   }
1001 };
1002 
1003 /// Patterns for reducing non-batch dimensions
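/// For example (illustrative), a `linalg.matmul` whose M extent is 1 (lhs
/// tensor<1x8xf32>, rhs tensor<8x16xf32>, init tensor<1x16xf32>) is rewritten
/// to a `linalg.vecmat` on tensor<8xf32>, tensor<8x16xf32> and tensor<16xf32>,
/// with the result expanded back to tensor<1x16xf32>.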
1004 template <typename FromOpTy, typename ToOpTy>
1005 struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1006   using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1007 
1008   /// Helper for determining whether the lhs/init or rhs/init are reduced.
1009   static bool constexpr reduceLeft =
1010       (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1011        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1012       (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1013        std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1014       (std::is_same_v<FromOpTy, MatmulOp> &&
1015        std::is_same_v<ToOpTy, VecmatOp>) ||
1016       (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1017        std::is_same_v<ToOpTy, VecmatOp>) ||
1018       (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1019 
1020   /// Look for non-batch spatial dims to collapse.
1021   LogicalResult
1022   getOperandUnitDims(LinalgOp op,
1023                      SmallVectorImpl<int64_t> &operandUnitDims) const override {
1024     FailureOr<ContractionDimensions> maybeContractionDims =
1025         inferContractionDims(op);
1026     if (failed(maybeContractionDims)) {
1027       LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
1028       return failure();
1029     }
1030     ContractionDimensions contractionDims = maybeContractionDims.value();
1031 
1032     if constexpr (reduceLeft) {
1033       auto m = contractionDims.m[0];
1034       SmallVector<std::pair<Value, unsigned>, 2> mOperands;
1035       op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
1036       if (mOperands.size() != 2)
1037         return failure();
1038       if (llvm::all_of(mOperands, [](auto pair) {
1039             return cast<ShapedType>(std::get<0>(pair).getType())
1040                        .getShape()[std::get<1>(pair)] == 1;
1041           })) {
1042         operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
1043                                                std::get<1>(mOperands[1])};
1044         return success();
1045       }
1046     } else {
1047       auto n = contractionDims.n[0];
1048       SmallVector<std::pair<Value, unsigned>, 2> nOperands;
1049       op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
1050       if (nOperands.size() != 2)
1051         return failure();
1052       if (llvm::all_of(nOperands, [](auto pair) {
1053             return cast<ShapedType>(std::get<0>(pair).getType())
1054                        .getShape()[std::get<1>(pair)] == 1;
1055           })) {
1056         operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
1057                                                std::get<1>(nOperands[1])};
1058         return success();
1059       }
1060     }
1061     LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
1062     return failure();
1063   }
1064 };
1065 
1066 } // namespace
1067 
1068 void mlir::linalg::populateContractionOpRankReducingPatterns(
1069     RewritePatternSet &patterns) {
1070   MLIRContext *context = patterns.getContext();
1071   // Unbatching patterns for unit batch size
1072   patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1073   patterns
1074       .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1075           context);
1076   patterns
1077       .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1078           context);
1079   patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1080   patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1081 
1082   // Non-batch rank 1 reducing patterns
1083   patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1084   patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1085   patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1086   patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1087   // Batch rank 1 reducing patterns
1088   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1089   patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1090   patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1091       context);
1092   patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1093       context);
1094 
1095   // Non-batch rank 0 reducing patterns
1096   patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
1097   patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
1098 }
1099