1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h"
27 #include "mlir/Dialect/Tensor/Utils/Utils.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StaticValueUtils.h"
30 #include "mlir/IR/AffineExpr.h"
31 #include "mlir/IR/AffineExprVisitor.h"
32 #include "mlir/IR/AffineMap.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/Pass/Pass.h"
36 #include "llvm/ADT/SetOperations.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 #include "llvm/Support/Debug.h"
39 #include <optional>
40 
41 #define DEBUG_TYPE "linalg-utils"
42 
43 using namespace mlir;
44 using namespace presburger;
45 using namespace mlir::linalg;
46 using namespace mlir::scf;
47 
48 namespace {
49 
50 // Helper visitor to determine whether an AffineExpr is tiled.
51 // This is achieved by traversing every AffineDimExpr with position `pos` and
52 // checking whether the corresponding `tileSizes[pos]` is non-zero.
53 // This also enforces that only positive coefficients occur in multiplications.
54 //
55 // Example:
56 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
57 //
58 struct TileCheck : public AffineExprVisitor<TileCheck> {
59   TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
60 
61   void visitDimExpr(AffineDimExpr expr) {
62     isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
63   }
64   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
65     visit(expr.getLHS());
66     visit(expr.getRHS());
67     if (expr.getKind() == mlir::AffineExprKind::Mul)
68       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
69              "nonpositive multiplying coefficient");
70   }
71   bool isTiled = false;
72   ArrayRef<OpFoldResult> tileSizes;
73 };
74 
75 } // namespace
76 
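// Returns true if `expr` references at least one dimension position `pos` for
// which `tileSizes[pos]` is non-zero, e.g. `d0 + d2` is tiled by [2, 0, 0] but
// not by [0, 2, 0].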
77 static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
78   if (!expr)
79     return false;
80   TileCheck t(tileSizes);
81   t.visit(expr);
82   return t.isTiled;
83 }
84 
85 // Checks whether the `map` varies with respect to a non-zero `tileSize`.
86 static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
87   if (!map)
88     return false;
89   for (unsigned r = 0; r < map.getNumResults(); ++r)
90     if (isTiled(map.getResult(r), tileSizes))
91       return true;
92   return false;
93 }
94 
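/// Matches a GenericOp whose single-block region is, roughly (a sketch):
///   ^bb0(%a: i32, %b: i32):
///     %0 = arith.addi %a, %b : i32
///     linalg.yield %0 : i32
/// i.e. exactly two signless int/float block arguments, a single
/// non-terminator op, and a yield of `a + b`; only integer addition (IAdd) is
/// recognized.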
95 std::optional<RegionMatcher::BinaryOpKind>
96 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
97   auto &region = op.getRegion();
98   if (!llvm::hasSingleElement(region))
99     return std::nullopt;
100 
101   Block &block = region.front();
102   if (block.getNumArguments() != 2 ||
103       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
104       !block.getArgument(1).getType().isSignlessIntOrFloat())
105     return std::nullopt;
106 
107   auto &ops = block.getOperations();
108   if (!llvm::hasSingleElement(block.without_terminator()))
109     return std::nullopt;
110 
111   using mlir::matchers::m_Val;
112   auto a = m_Val(block.getArgument(0));
113   auto b = m_Val(block.getArgument(1));
114 
115   auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
116   if (addPattern.match(&ops.back()))
117     return BinaryOpKind::IAdd;
118 
119   return std::nullopt;
120 }
121 
122 /// Explicit instantiation of loop nest generator for different loop types.
123 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
124 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
125 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
126 
127 /// Given a list of subview ranges, extract individual values for lower, upper
128 /// bounds and steps and put them into the corresponding vectors.
129 static void unpackRanges(OpBuilder &builder, Location loc,
130                          ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
131                          SmallVectorImpl<Value> &ubs,
132                          SmallVectorImpl<Value> &steps) {
133   for (Range range : ranges) {
134     lbs.emplace_back(
135         getValueOrCreateConstantIndexOp(builder, loc, range.offset));
136     ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
137     steps.emplace_back(
138         getValueOrCreateConstantIndexOp(builder, loc, range.stride));
139   }
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Utilities for inferring various semantic properties of Linalg ops.
144 //===----------------------------------------------------------------------===//
145 
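/// Returns the dimension positions that appear in exactly one result of
/// `opOperand`'s indexing map and whose iterator type is `iter`. E.g., for a
/// matmul LHS with map (d0, d1, d2) -> (d0, d2) and iterators
/// [parallel, parallel, reduction], this returns {0} for parallel and {2} for
/// reduction.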
146 DenseSet<int64_t> mlir::linalg::findPermutationsIndexingOperand(
147     LinalgOp linalgOp, OpOperand *opOperand, utils::IteratorType iter) {
148   DenseSet<int64_t> res;
149   assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
150   AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
151   for (AffineExpr e : indexingMap.getResults()) {
152     if (auto d = e.dyn_cast<AffineDimExpr>()) {
153       if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
154           llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
155             return e.isFunctionOfDim(d.getPosition());
156           }) == 1)
157         res.insert(d.getPosition());
158     }
159   }
160   return res;
161 }
162 
163 namespace {
164 auto par = utils::IteratorType::parallel;
165 auto red = utils::IteratorType::reduction;
166 } // namespace
167 
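/// Returns true only if matmul dimensions can be inferred and each of the
/// m/n/k candidate sets contains at least one of the three innermost loop
/// dimensions (numLoops - 3, numLoops - 2, numLoops - 1).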
168 bool mlir::linalg::containsMostMinorMatmul(LinalgOp linalgOp) {
169   FailureOr<EmbeddedMatmulDimsCandidates> res = inferMatmulDims(linalgOp);
170   if (failed(res))
171     return false;
172   int64_t numLoops = linalgOp.getNumLoops();
173   for (const DenseSet<int64_t> &s : {res->mPos, res->nPos, res->kPos}) {
174     if (s.contains(numLoops - 3) || s.contains(numLoops - 2) ||
175         s.contains(numLoops - 1))
176       continue;
177     return false;
178   }
179   return true;
180 }
181 
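/// Example (a plain matmul with maps A: (d0, d1, d2) -> (d0, d2),
/// B: (d0, d1, d2) -> (d2, d1), C: (d0, d1, d2) -> (d0, d1) and iterators
/// [parallel, parallel, reduction]): the inferred candidates are
/// mPos = {0}, nPos = {1}, kPos = {2}.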
182 FailureOr<EmbeddedMatmulDimsCandidates>
183 mlir::linalg::inferMatmulDims(LinalgOp linalgOp) {
184   if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
185     return failure();
186 
187   DenseSet<int64_t> a = findPermutationsIndexingOperand(
188       linalgOp, linalgOp.getDpsInputOperand(0), par);
189   DenseSet<int64_t> b = findPermutationsIndexingOperand(
190       linalgOp, linalgOp.getDpsInputOperand(1), par);
191   DenseSet<int64_t> c = findPermutationsIndexingOperand(
192       linalgOp, linalgOp.getDpsInitOperand(0), par);
193 
194   // A & C - B are the iterators involved in an outer-product along A (the LHS).
195   DenseSet<int64_t> ac = a;
196   llvm::set_intersect(ac, c);
197   llvm::set_subtract(ac, b);
198   // B & C - A are the iterators involved in an outer-product along B (the RHS).
199   DenseSet<int64_t> bc = b;
200   llvm::set_intersect(bc, c);
201   llvm::set_subtract(bc, a);
202 
203   // Note: if we ever need them, A & B & C would be "batch" dimensions.
204 
205   // The red iterators common to A and B are the reduction dimensions.
206   DenseSet<int64_t> ra = findPermutationsIndexingOperand(
207       linalgOp, linalgOp.getDpsInputOperand(0), red);
208   DenseSet<int64_t> rb = findPermutationsIndexingOperand(
209       linalgOp, linalgOp.getDpsInputOperand(1), red);
210   llvm::set_intersect(ra, rb);
211 
212   if (ac.empty() || bc.empty() || ra.empty())
213     return failure();
214 
215   // Pick the first one in each set.
216   // TODO: Better heuristic (e.g., pick dims based on a packing-based metric).
217   return EmbeddedMatmulDimsCandidates{ac, bc, ra};
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // General utilities
222 //===----------------------------------------------------------------------===//
223 
224 namespace mlir {
225 namespace linalg {
226 
227 bool allIndexingsAreProjectedPermutation(LinalgOp op) {
228   return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
229     return m.isProjectedPermutation(/*allowZeroInResults=*/true);
230   });
231 }
232 
233 bool hasOnlyScalarElementwiseOp(Region &r) {
234   if (!llvm::hasSingleElement(r))
235     return false;
236   for (Operation &op : r.front()) {
237     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
238               linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
239           OpTrait::hasElementwiseMappableTraits(&op)) ||
240         llvm::any_of(op.getResultTypes(),
241                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
242       return false;
243   }
244   return true;
245 }
246 
247 bool isElementwise(LinalgOp op) {
248   if (op.getNumLoops() != op.getNumParallelLoops())
249     return false;
250 
251   if (!allIndexingsAreProjectedPermutation(op))
252     return false;
253 
254   // TODO: relax the restrictions on indexing map.
255   for (OpOperand *opOperand : op.getDpsInitOperands()) {
256     if (!op.getMatchingIndexingMap(opOperand).isPermutation())
257       return false;
258   }
259   return hasOnlyScalarElementwiseOp(op->getRegion(0));
260 }
261 
262 bool isParallelIterator(utils::IteratorType iteratorType) {
263   return iteratorType == utils::IteratorType::parallel;
264 }
265 
266 bool isReductionIterator(utils::IteratorType iteratorType) {
267   return iteratorType == utils::IteratorType::reduction;
268 }
269 
270 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
271 /// the type of `source`.
272 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
273   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
274     return b.createOrFold<memref::DimOp>(loc, source, dim);
275   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
276     return b.createOrFold<tensor::DimOp>(loc, source, dim);
277   llvm_unreachable("Expected MemRefType or TensorType");
278 }
279 
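/// Returns the size of dimension `dim` of `source` as an OpFoldResult: an
/// index attribute when the dimension is statically known, otherwise the
/// result of a (possibly folded) memref.dim / tensor.dim op.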
280 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
281                                int64_t dim) {
282   auto shapedType = source.getType().cast<ShapedType>();
283   if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
284     return createOrFoldDimOp(b, loc, source, dim);
285   return b.getIndexAttr(shapedType.getDimSize(dim));
286 }
287 
288 /// Given a shaped value, retrieves the value of each dynamic dimension by
289 /// constructing the necessary DimOps.
290 SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
291   SmallVector<Value, 4> dynOperands;
292   auto shapedType = val.getType().cast<ShapedType>();
293   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
294     if (dim.value() == ShapedType::kDynamic)
295       dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
296   }
297   return dynOperands;
298 }
299 
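/// Tries to reuse an existing pad instead of creating a new one. A sketch of
/// the matched pattern (the middle may be any chain of destination-style
/// LinalgOps writing into the padded value):
///   %s      = tensor.extract_slice %t ...           // padOpSliceOp
///   %padded = tensor.pad %s low[0, ...] high[...]   // padOp, zero low pad
///   %r      = linalg.generic ... outs(%padded) ...  // padded LinalgOp chain
///   %source = tensor.extract_slice %r ...           // the `source` argument
/// When the slice sizes, the result type and the padding value all match,
/// `%r` (i.e. sliceOp.getSource()) is returned; otherwise a fresh pad-high op
/// is created.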
300 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
301                             Value source, Value pad, bool nofold) {
302   // Exit if `source` is not defined by an ExtractSliceOp.
303   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
304   if (!sliceOp)
305     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
306 
307   // Search the `source` use-def chain for padded LinalgOps.
308   Value current = sliceOp.getSource();
309   while (current) {
310     auto linalgOp = current.getDefiningOp<LinalgOp>();
311     if (!linalgOp)
312       break;
313     OpResult opResult = current.cast<OpResult>();
314     current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
315   }
316   auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
317 
318   // Exit if the search fails to match a tensor::PadOp at the end of the matched
319   // LinalgOp sequence.
320   if (!padOp)
321     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
322 
323   // Exit if the padded result type does not match.
324   if (sliceOp.getSource().getType() != type)
325     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
326 
327   // Exit if the LinalgOps are not high padded.
328   if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
329         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
330       }))
331     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
332 
333   // Exit if `padOpSliceOp`, which defines the slice used by
334   // `padOp`, is rank-reducing.
335   auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
336   if (!padOpSliceOp ||
337       sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
338     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
339 
340   // Exit if the sizes of `sliceOp` do not match the sizes of the slice
341   // padded by `padOp`.
342   if (llvm::any_of(
343           llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
344           [](std::tuple<OpFoldResult, OpFoldResult> it) {
345             return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
346           }))
347     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
348 
349   // Exit if the padding values do not match.
350   Attribute padOpPadAttr, padAttr;
351   Value padOpPad = padOp.getConstantPaddingValue();
352   if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
353       !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
354     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
355 
356   // Return the padded result if the padding values and sizes match.
357   return sliceOp.getSource();
358 }
359 
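/// Builds a linalg.generic that transposes `inputTensor` into `outputTensor`
/// following `transposeVector`. For example (a sketch), transposeVector =
/// [1, 0] on a tensor<4x8xf32> input and tensor<8x4xf32> output produces
/// indexing maps (d0, d1) -> (d1, d0) and (d0, d1) -> (d0, d1) with a body
/// that yields its first block argument.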
360 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
361                           Value outputTensor,
362                           ArrayRef<int64_t> transposeVector) {
363   auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
364   Type elementType = resultTensorType.getElementType();
365 
366   assert(isPermutationVector(transposeVector) &&
367          "expect transpose vector to be a permutation");
368   assert(transposeVector.size() ==
369              static_cast<size_t>(resultTensorType.getRank()) &&
370          "expect transpose vector size to match result tensor rank");
371 
372   // Compute the transpose and the identity indexing maps.
373   SmallVector<AffineMap> indexingMaps = {
374       inversePermutation(AffineMap::getPermutationMap(
375           SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
376           b.getContext())),
377       AffineMap::getMultiDimIdentityMap(transposeVector.size(),
378                                         b.getContext())};
379   SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
380                                                  utils::IteratorType::parallel);
381 
382   // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
383   auto transposeOp =
384       b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
385                           indexingMaps, iteratorTypes);
386   Region &body = transposeOp.getRegion();
387   body.push_back(new Block());
388   body.front().addArguments({elementType, elementType}, {loc, loc});
389 
390   // Create the body of the transpose operation.
391   OpBuilder::InsertionGuard g(b);
392   b.setInsertionPointToEnd(&body.front());
393   b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
394   return transposeOp;
395 }
396 
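/// Creates a linalg.generic that copies `from` into `to` element by element,
/// using identity indexing maps and all-parallel iterators; both memrefs must
/// have the same rank.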
397 GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
398   auto memrefTypeTo = to.getType().cast<MemRefType>();
399 #ifndef NDEBUG
400   auto memrefTypeFrom = from.getType().cast<MemRefType>();
401   assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
402          "`from` and `to` memref must have the same rank");
403 #endif // NDEBUG
404 
405   AffineMap id =
406       AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
407   SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
408                                                  utils::IteratorType::parallel);
409   return b.create<linalg::GenericOp>(
410       loc,
411       /*inputs=*/from,
412       /*outputs=*/to,
413       /*indexingMaps=*/llvm::ArrayRef({id, id}),
414       /*iteratorTypes=*/iteratorTypes,
415       [](OpBuilder &b, Location loc, ValueRange args) {
416         b.create<linalg::YieldOp>(loc, args.front());
417       });
418 }
419 
420 /// Specialization to build an scf "for" nest.
421 template <>
422 void GenerateLoopNest<scf::ForOp>::doit(
423     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
424     ArrayRef<utils::IteratorType> iteratorTypes,
425     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
426                                   ValueRange)>
427         bodyBuilderFn,
428     ArrayRef<linalg::ProcInfo> procInfo) {
429   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
430          "expected as many entries for proc info as number of loops, even if "
431          "they are null entries");
432   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
433                                              ? SmallVector<Value>{}
434                                              : linalgOp.getDpsInitOperands();
435 
436   SmallVector<Value, 4> lbs, ubs, steps;
437   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
438   LoopNest loopNest = mlir::scf::buildLoopNest(
439       b, loc, lbs, ubs, steps, iterArgInitValues,
440       [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
441         assert(iterArgs.size() == iterArgInitValues.size() &&
442                "expect the number of output tensors and iter args to match");
443         SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
444         if (!iterArgs.empty()) {
445           operandValuesToUse = linalgOp.getDpsInputOperands();
446           operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
447         }
448         return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
449       });
450 
451   if (loopNest.loops.empty() || procInfo.empty())
452     return;
453 
454   // Filter out scf.for loops that were created out of parallel dimensions.
455   for (const auto &loop : llvm::enumerate(loopNest.loops)) {
456     if (procInfo[loop.index()].distributionMethod ==
457         DistributionMethod::Cyclic) {
458       mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
459                             procInfo[loop.index()].nprocs);
460     }
461   }
462 }
463 
464 /// Specialization to build an affine "for" nest.
465 template <>
466 void GenerateLoopNest<AffineForOp>::doit(
467     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
468     ArrayRef<utils::IteratorType> iteratorTypes,
469     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
470                                   ValueRange)>
471         bodyBuilderFn,
472     ArrayRef<linalg::ProcInfo> /*procInfo*/) {
473   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
474                                              ? SmallVector<Value>{}
475                                              : linalgOp.getDpsInitOperands();
476   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
477   SmallVector<Value, 4> lbs, ubs, steps;
478   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
479 
480   // Affine loops require constant steps.
481   SmallVector<int64_t, 4> constantSteps;
482   constantSteps.reserve(steps.size());
483   for (Value v : steps) {
484     auto op = v.getDefiningOp<arith::ConstantIndexOp>();
485     assert(op && "Affine loops require constant steps");
486     constantSteps.push_back(op.value());
487   }
488 
489   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
490                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
491                               bodyBuilderFn(b, loc, ivs,
492                                             linalgOp->getOperands());
493                             });
494 }
495 
496 /// Update `lb`, `ub` and `step` to get the per-processor `lb`, `ub` and `step`.
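/// Concretely: `lb` becomes `lb + procId * step` and `step` becomes
/// `nprocs * step`; `ub` is left unchanged.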
497 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
498                                        Value nprocs, Value &lb, Value &ub,
499                                        Value &step) {
500   AffineExpr d0, d1;
501   bindDims(b.getContext(), d0, d1);
502   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
503   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
504   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
505 }
506 
507 /// Generates a loop nest consisting of scf.parallel and scf.for, depending
508 /// on the `iteratorTypes`. Consecutive parallel loops create a single
509 /// scf.parallel operation; each sequential loop creates a new scf.for
510 /// operation. The body of the innermost loop is populated by
511 /// `bodyBuilderFn` that accepts a range of induction variables for all
512 /// loops. `ivStorage` is used to store the partial list of induction
513 /// variables.
514 // TODO: this function can be made iterative instead. However, it
515 // will have at most as many recursive calls as nested loops, which rarely
516 // exceeds 10.
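// When `procInfo` is non-empty, the leading loops that share the same
// distribution method (None, Cyclic, CyclicNumProcsGeNumIters or
// CyclicNumProcsEqNumIters) are materialized together before recursing on the
// remaining loops.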
517 static void generateParallelLoopNest(
518     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
519     ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
520     ArrayRef<linalg::ProcInfo> procInfo,
521     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
522     SmallVectorImpl<Value> &ivStorage) {
523   assert(lbs.size() == ubs.size());
524   assert(lbs.size() == steps.size());
525   assert(lbs.size() == iteratorTypes.size());
526   assert(procInfo.empty() || (lbs.size() == procInfo.size()));
527 
528   // If there are no (more) loops to be generated, generate the body and be
529   // done with it.
530   if (iteratorTypes.empty()) {
531     bodyBuilderFn(b, loc, ivStorage);
532     return;
533   }
534 
535   // If there are no outer parallel loops, generate one sequential loop and
536   // recurse.
537   if (!isParallelIterator(iteratorTypes.front())) {
538     LoopNest singleLoop = buildLoopNest(
539         b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
540         [&](OpBuilder &b, Location loc, ValueRange ivs) {
541           ivStorage.append(ivs.begin(), ivs.end());
542           generateParallelLoopNest(
543               b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
544               iteratorTypes.drop_front(),
545               procInfo.empty() ? procInfo : procInfo.drop_front(),
546               bodyBuilderFn, ivStorage);
547         });
548     return;
549   }
550 
551   unsigned nLoops = iteratorTypes.size();
552   unsigned numProcessed = 0;
553   DistributionMethod distributionMethod = DistributionMethod::None;
554   if (procInfo.empty()) {
555     numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
556   } else {
557     distributionMethod = procInfo.front().distributionMethod;
558     numProcessed =
559         nLoops - procInfo
560                      .drop_while([&](linalg::ProcInfo p) {
561                        return p.distributionMethod == distributionMethod;
562                      })
563                      .size();
564   }
565 
566   auto remainderProcInfo =
567       procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
568   switch (distributionMethod) {
569   case DistributionMethod::None: {
570     // Generate a single parallel loop-nest operation for all outermost
571     // parallel loops and recurse.
572     b.create<scf::ParallelOp>(
573         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
574         steps.take_front(numProcessed),
575         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
576           ivStorage.append(localIvs.begin(), localIvs.end());
577           generateParallelLoopNest(
578               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
579               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
580               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
581               bodyBuilderFn, ivStorage);
582         });
583     return;
584   }
585   case DistributionMethod::Cyclic: {
586     // Generate a single parallel loop-nest operation for all outermost
587     // parallel loops and recurse.
588     b.create<scf::ParallelOp>(
589         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
590         steps.take_front(numProcessed),
591         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
592           ivStorage.append(localIvs.begin(), localIvs.end());
593           generateParallelLoopNest(
594               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
595               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
596               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
597               bodyBuilderFn, ivStorage);
598         });
599     return;
600   }
601   case DistributionMethod::CyclicNumProcsGeNumIters: {
602     // Check (for the processed loops) that the iteration is in-bounds.
603     ArithBuilder ab(b, loc);
604     Value cond = ab.slt(lbs[0], ubs[0]);
605     for (unsigned i = 1; i < numProcessed; ++i)
606       cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
607     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
608     b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
609       generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
610                                ubs.drop_front(numProcessed),
611                                steps.drop_front(numProcessed),
612                                iteratorTypes.drop_front(numProcessed),
613                                remainderProcInfo, bodyBuilderFn, ivStorage);
614       b.create<scf::YieldOp>(loc, ValueRange{});
615     });
616     return;
617   }
618   case DistributionMethod::CyclicNumProcsEqNumIters:
619     // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
620     // with inner loop generation.
621     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
622     generateParallelLoopNest(
623         b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
624         steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
625         remainderProcInfo, bodyBuilderFn, ivStorage);
626     return;
627   }
628 }
629 
630 /// Specialization for generating a mix of parallel and sequential scf loops.
631 template <>
632 void GenerateLoopNest<scf::ParallelOp>::doit(
633     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
634     ArrayRef<utils::IteratorType> iteratorTypes,
635     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
636                                   ValueRange)>
637         bodyBuilderFn,
638     ArrayRef<linalg::ProcInfo> procInfo) {
639   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
640                                              ? SmallVector<Value>{}
641                                              : linalgOp.getDpsInitOperands();
642   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
643   // This function may be passed more iterator types than ranges.
644   assert(iteratorTypes.size() >= loopRanges.size() &&
645          "expected iterator type for all ranges");
646   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
647          "expected proc information for all loops when present");
648   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
649   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
650   unsigned numLoops = iteratorTypes.size();
651   ivs.reserve(numLoops);
652   lbsStorage.reserve(numLoops);
653   ubsStorage.reserve(numLoops);
654   stepsStorage.reserve(numLoops);
655 
656   // Get the loop lb, ub, and step.
657   unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
658 
659   // Modify the lb, ub, and step based on the distribution options.
660   for (const auto &it : llvm::enumerate(procInfo)) {
661     if (it.value().distributionMethod != linalg::DistributionMethod::None) {
662       updateBoundsForCyclicDistribution(
663           b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
664           ubsStorage[it.index()], stepsStorage[it.index()]);
665     }
666   }
667   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
668   generateParallelLoopNest(
669       b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
670       [&](OpBuilder &b, Location loc, ValueRange ivs) {
671         bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
672       },
673       ivs);
674 
675   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
676 }
677 
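/// Materializes `sliceParams` as a memref.subview when `valueToTile` has
/// memref type and as a tensor.extract_slice when it has ranked tensor type.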
678 static Value materializeTiledShape(OpBuilder &builder, Location loc,
679                                    Value valueToTile,
680                                    const SliceParameters &sliceParams) {
681   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
682   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
683                       .Case([&](MemRefType) {
684                         return builder.create<memref::SubViewOp>(
685                             loc, valueToTile, sliceParams.offsets,
686                             sliceParams.sizes, sliceParams.strides);
687                       })
688                       .Case([&](RankedTensorType) {
689                         return builder.create<tensor::ExtractSliceOp>(
690                             loc, valueToTile, sliceParams.offsets,
691                             sliceParams.sizes, sliceParams.strides);
692                       })
693                       .Default([](ShapedType) -> Operation * {
694                         llvm_unreachable("Unexpected shaped type");
695                       });
696   return sliceOp->getResult(0);
697 }
698 
699 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
700                      ArrayRef<OpFoldResult> tileSizes, AffineMap map,
701                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
702                      ArrayRef<OpFoldResult> subShapeSizes,
703                      bool omitPartialTileCheck) {
704   SliceParameters sliceParams =
705       computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
706                              ubs, subShapeSizes, omitPartialTileCheck);
707   return materializeTiledShape(builder, loc, valueToTile, sliceParams);
708 }
709 
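/// Computes the offsets, sizes and strides of the slice of `valueToTile` used
/// within one tile. Untiled dimensions get a full-size slice (offset 0,
/// stride 1); tiled dimensions get offsets/sizes composed through `map` and,
/// when a partial boundary tile is possible and `omitPartialTileCheck` is not
/// set, an affine.min-bounded size to stay in bounds.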
710 SliceParameters
711 computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
712                        ArrayRef<OpFoldResult> tileSizes, AffineMap map,
713                        ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
714                        ArrayRef<OpFoldResult> subShapeSizes,
715                        bool omitPartialTileCheck) {
716   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
717   assert(shapedType && "only shaped types can be tiled");
718   ArrayRef<int64_t> shape = shapedType.getShape();
719   int64_t rank = shapedType.getRank();
720 
721   // Compute offsets/sizes/strides for the tile.
722   SliceParameters sliceParams;
723   sliceParams.offsets.reserve(rank);
724   sliceParams.sizes.reserve(rank);
725   sliceParams.strides.reserve(rank);
726   for (unsigned r = 0; r < rank; ++r) {
727     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
728     if (!isTiled(map.getSubMap({r}), tileSizes)) {
729       sliceParams.offsets.push_back(builder.getIndexAttr(0));
730       OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
731       sliceParams.sizes.push_back(dim);
732       sliceParams.strides.push_back(builder.getIndexAttr(1));
733       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
734       continue;
735     }
736     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
737 
738     // Tiling creates a new slice at the proper index; the slice step is 1
739     // (i.e. the op does not subsample, stepping occurs in the loop).
740     auto m = map.getSubMap({r});
741     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
742     IRRewriter rewriter(builder);
743     OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
744     sliceParams.offsets.push_back(offset);
745     OpFoldResult closedIntSize =
746         makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
747     // The resulting size needs to be turned back into a half-open interval.
748     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
749     OpFoldResult size =
750         makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
751     LLVM_DEBUG(llvm::dbgs()
752                << "computeSliceParameters: raw size: " << size << "\n");
753     LLVM_DEBUG(llvm::dbgs()
754                << "computeSliceParameters: new offset: " << offset << "\n");
755     sliceParams.strides.push_back(builder.getIndexAttr(1));
756 
757     if (omitPartialTileCheck) {
758       // We statically know that the partial/boundary tile condition is
759       // unnecessary.
760       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
761       sliceParams.sizes.push_back(size);
762       continue;
763     }
764 
765     // The size of the subview / extract_slice should be trimmed to avoid
766     // out-of-bounds accesses, unless:
767     // a. We statically know the subshape size divides the shape size evenly.
768     // b. The subshape size is 1. According to the way the loops are set up,
769     //    tensors with "0" dimensions would never be constructed.
770     int64_t shapeSize = shape[r];
771     std::optional<int64_t> sizeCst = getConstantIntValue(size);
772     auto hasTileSizeOne = sizeCst && *sizeCst == 1;
773     auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
774                          ((shapeSize % *sizeCst) == 0);
775     if (!hasTileSizeOne && !dividesEvenly) {
776       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
777                               << ", size: " << size
778                               << ": make sure in bound with affine.min\n");
779 
780       AffineExpr dim0, dim1, dim2;
781       bindDims(builder.getContext(), dim0, dim1, dim2);
782 
783       // Get the dimension size for this dimension. We need to first calculate
784       // the max index and then add one. This is important because for
785       // convolution ops, we have its input window dimension's affine map of the
786       // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
787       // dimension and `s0` is the stride. Directly using the dimension size of
788       // the output/filter window dimensions would cause incorrect calculations.
789       AffineMap minusOneMap =
790           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
791               .front();
792       AffineMap plusOneMap =
793           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
794               .front();
795       SmallVector<OpFoldResult> maxIndices =
796           llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
797             return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
798                                                  {ub});
799           }));
800       OpFoldResult maxIndex =
801           makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
802       OpFoldResult d =
803           makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
804 
805       // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
806       AffineMap minMap = AffineMap::inferFromExprList(
807                              {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
808                              .front();
809       size =
810           makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
811     }
812     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
813     sliceParams.sizes.push_back(size);
814   }
815   return sliceParams;
816 }
817 
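/// Returns one offset per loop: the current induction variable for tiled loops
/// and 0 for untiled ones. E.g., tileSizes = [0, 8, 0, 16] with ivs = [%i, %j]
/// gives offsets = [0, %i, 0, %j].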
818 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
819                                              ArrayRef<OpFoldResult> ivs,
820                                              ArrayRef<OpFoldResult> tileSizes) {
821   SmallVector<OpFoldResult> offsets;
822   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
823     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
824     bool isTiled = !isZeroIndex(tileSizes[idx]);
825     offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
826     LLVM_DEBUG(llvm::dbgs()
827                << "computeTileOffsets: " << offsets.back() << "\n");
828   }
829   return offsets;
830 }
831 
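/// Returns closed-interval sizes (i.e. size - 1) to compose with subshape
/// maps: the tile size for tiled loops, the size bound otherwise. E.g.,
/// tileSizes = [0, 8] with sizeBounds = [%d0, %d1] gives [%d0 - 1, 7].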
832 SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
833                                            ArrayRef<OpFoldResult> tileSizes,
834                                            ArrayRef<OpFoldResult> sizeBounds) {
835   SmallVector<OpFoldResult> sizes;
836   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
837     bool isTiled = !isZeroIndex(tileSizes[idx]);
838     // Before composing, we need to make the range a closed interval.
839     OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
840     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
841     IRRewriter rewriter(b);
842     sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
843     LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
844   }
845   return sizes;
846 }
847 
848 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
849   if (op.hasBufferSemantics())
850     return {};
851   return llvm::to_vector(
852       llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
853         return operands[opOperand->getOperandNumber()].getType();
854       }));
855 }
856 
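/// Creates a tensor.insert_slice for every tiled init operand, writing the
/// tiled `results` back into the original destination tensors; results whose
/// init was not sliced are returned unchanged, and nothing is returned for ops
/// with buffer semantics.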
857 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
858                                     LinalgOp op, ValueRange operands,
859                                     ValueRange results) {
860   if (op.hasBufferSemantics())
861     return {};
862   SmallVector<Value> tensorResults;
863   tensorResults.reserve(results.size());
864   // Insert an insert_slice for each output tensor.
865   unsigned resultIdx = 0;
866   for (OpOperand *opOperand : op.getDpsInitOperands()) {
867     // TODO: use an interface/adaptor to avoid leaking position in
868     // `tiledOperands`.
869     Value outputTensor = operands[opOperand->getOperandNumber()];
870     if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
871       Value inserted = builder.create<tensor::InsertSliceOp>(
872           loc, sliceOp.getSource().getType(), results[resultIdx],
873           sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
874           sliceOp.getStrides(), sliceOp.getStaticOffsets(),
875           sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
876       tensorResults.push_back(inserted);
877     } else {
878       tensorResults.push_back(results[resultIdx]);
879     }
880     ++resultIdx;
881   }
882   return tensorResults;
883 }
884 
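/// Computes SliceParameters for every value in `valuesToTile`; an entry is
/// std::nullopt for operands that need no slice, i.e. operands that are not
/// tiled by their indexing map and are not ranked-tensor DPS inits.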
885 SmallVector<std::optional<SliceParameters>>
886 computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
887                           ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
888                           ArrayRef<OpFoldResult> tileSizes,
889                           ArrayRef<OpFoldResult> sizeBounds,
890                           bool omitPartialTileCheck) {
891   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
892                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
893                            [](OpFoldResult v) { return !isZeroIndex(v); })) &&
894          "expected as many ivs as non-zero sizes");
895 
896   // Construct (potentially temporary) mins and maxes on which to apply maps
897   // that define tile subshapes.
898   SmallVector<OpFoldResult> lbs =
899       computeTileOffsets(builder, loc, ivs, tileSizes);
900   SmallVector<OpFoldResult> subShapeSizes =
901       computeTileSizes(builder, loc, tileSizes, sizeBounds);
902 
903   assert(static_cast<int64_t>(valuesToTile.size()) <=
904              linalgOp->getNumOperands() &&
905          "more value to tile than operands.");
906   SmallVector<std::optional<SliceParameters>> allSliceParams;
907   allSliceParams.reserve(valuesToTile.size());
908   for (auto [opOperand, val] :
909        llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
910     Value shapedOp = val;
911     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
912     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
913     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
914     // an extract/insert slice pair for all output tensors simplifies follow-up
915     // transformations such as padding and bufferization since the
916     // extract/insert slice pairs make the accessed iteration argument
917     // subdomains explicit.
918 
919     Type operandType = opOperand.get().getType();
920     if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
921                                       linalgOp.isDpsInit(&opOperand))) {
922       allSliceParams.push_back(std::nullopt);
923       LLVM_DEBUG(llvm::dbgs()
924                  << ": not tiled: use shape: " << operandType << "\n");
925       continue;
926     }
927     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
928 
929     allSliceParams.push_back(computeSliceParameters(
930         builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
931         omitPartialTileCheck));
932   }
933 
934   return allSliceParams;
935 }
936 
937 SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
938                                    LinalgOp linalgOp, ValueRange valuesToTile,
939                                    ArrayRef<OpFoldResult> ivs,
940                                    ArrayRef<OpFoldResult> tileSizes,
941                                    ArrayRef<OpFoldResult> sizeBounds,
942                                    bool omitPartialTileCheck) {
943   SmallVector<std::optional<SliceParameters>> allSliceParameter =
944       computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
945                                 tileSizes, sizeBounds, omitPartialTileCheck);
946   SmallVector<Value> tiledShapes;
947   for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
948     Value valueToTile = std::get<0>(item);
949     std::optional<SliceParameters> sliceParams = std::get<1>(item);
950     tiledShapes.push_back(
951         sliceParams.has_value()
952             ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
953             : valueToTile);
954   }
955   return tiledShapes;
956 }
957 
958 void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
959                    ArrayRef<OpFoldResult> offsets) {
960   IRRewriter rewriter(b);
961   offsetIndices(rewriter, linalgOp, offsets);
962 }
963 
964 void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
965                    ArrayRef<OpFoldResult> offsets) {
966   if (!linalgOp.hasIndexSemantics())
967     return;
968 
969   for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
970     if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
971       continue;
972     OpBuilder::InsertionGuard guard(b);
973     b.setInsertionPointAfter(indexOp);
974     AffineExpr index, offset;
975     bindDims(b.getContext(), index, offset);
976     OpFoldResult applied = makeComposedFoldedAffineApply(
977         b, indexOp.getLoc(), index + offset,
978         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
979     Value materialized =
980         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
981     b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
982       return use.getOwner() != materialized.getDefiningOp();
983     });
984   }
985 }
986 
987 /// Get the reassociation maps to fold the result of an extract_slice (or
988 /// source of an insert_slice) operation with given offsets and sizes to its
989 /// rank-reduced version. This is only done for the cases where the size is 1
990 /// and offset is 0. Strictly speaking the offset 0 is not required in general,
991 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
992 /// potentially cannot be handled).
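/// E.g., mixedSizes = [1, 128, 1, 32] produces the reassociation
/// [[0, 1], [2, 3]], folding each unit dimension into the next non-unit one.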
993 std::optional<SmallVector<ReassociationIndices>>
994 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
995   SmallVector<ReassociationIndices> reassociation;
996   ReassociationIndices curr;
997   for (const auto &it : llvm::enumerate(mixedSizes)) {
998     auto dim = it.index();
999     auto size = it.value();
1000     curr.push_back(dim);
1001     auto attr = size.dyn_cast<Attribute>();
1002     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
1003       continue;
1004     reassociation.emplace_back(ReassociationIndices{});
1005     std::swap(reassociation.back(), curr);
1006   }
1007   // When the reassociations are not empty, fold the remaining unit dimensions
1008   // into the last dimension. If the reassociations so far are empty, leave them
1009   // empty; this will fold everything to a rank-0 tensor.
1010   if (!curr.empty() && !reassociation.empty())
1011     reassociation.back().append(curr.begin(), curr.end());
1012   return reassociation;
1013 }
1014 
1015 /// Return the neutral (identity) numeric value associated with the given op.
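/// E.g., arith.addf -> 0.0, arith.mulf -> 1.0, arith.maxf -> -inf,
/// arith.addi/ori/xori -> 0, arith.andi -> -1 (all ones), arith.muli -> 1.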
1016 std::optional<Attribute> getNeutralElement(Operation *op) {
1017   // Builder only used as helper for attribute creation.
1018   OpBuilder b(op->getContext());
1019   Type resultType = op->getResult(0).getType();
1020   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1021     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
1022     if (isa<arith::AddFOp>(op))
1023       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
1024     if (isa<arith::MulFOp>(op))
1025       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
1026     if (isa<arith::MaxFOp>(op))
1027       return b.getFloatAttr(resultType,
1028                             llvm::APFloat::getInf(semantic, /*Negative=*/true));
1029     if (isa<arith::MinFOp>(op))
1030       return b.getFloatAttr(
1031           resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
1032     return std::nullopt;
1033   }
1034   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
1035     return b.getIntegerAttr(resultType, 0);
1036   if (isa<arith::AndIOp>(op))
1037     return b.getIntegerAttr(resultType, -1);
1038   if (isa<arith::MaxSIOp>(op))
1039     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
1040   if (isa<arith::MinSIOp>(op))
1041     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
1042   if (isa<arith::MulIOp>(op))
1043     return b.getIntegerAttr(resultType, 1);
1044   return std::nullopt;
1045 }
1046 
1047 } // namespace linalg
1048 } // namespace mlir
1049