xref: /llvm-project/mlir/lib/Dialect/Linalg/Utils/Utils.cpp (revision 96179dff46a9a3981708b06bc9e0f981be4cc1a8)
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h"
27 #include "mlir/Dialect/Tensor/Utils/Utils.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StaticValueUtils.h"
30 #include "mlir/IR/AffineExpr.h"
31 #include "mlir/IR/AffineExprVisitor.h"
32 #include "mlir/IR/AffineMap.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/Pass/Pass.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/Debug.h"
38 #include <optional>
39 
40 #define DEBUG_TYPE "linalg-utils"
41 
42 using namespace mlir;
43 using namespace presburger;
44 using namespace mlir::linalg;
45 using namespace mlir::scf;
46 
47 namespace {
48 
49 // Helper visitor to determine whether an AffineExpr is tiled.
50 // This is achieved by traversing every AffineDimExpr with position `pos` and
51 // checking whether the corresponding `tileSizes[pos]` is non-zero.
52 // This also enforces that only positive coefficients occur in multiplications.
53 //
54 // Example:
55 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
56 //
57 struct TileCheck : public AffineExprVisitor<TileCheck> {
58   TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
59 
60   void visitDimExpr(AffineDimExpr expr) {
61     isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
62   }
63   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
64     visit(expr.getLHS());
65     visit(expr.getRHS());
66     if (expr.getKind() == mlir::AffineExprKind::Mul)
67       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
68              "nonpositive multiplying coefficient");
69   }
70   bool isTiled = false;
71   ArrayRef<OpFoldResult> tileSizes;
72 };
73 
74 } // namespace
75 
76 static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
77   if (!expr)
78     return false;
79   TileCheck t(tileSizes);
80   t.visit(expr);
81   return t.isTiled;
82 }
83 
84 // Checks whether the `map` varies with respect to a non-zero `tileSize`.
85 static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
86   if (!map)
87     return false;
88   for (unsigned r = 0; r < map.getNumResults(); ++r)
89     if (isTiled(map.getResult(r), tileSizes))
90       return true;
91   return false;
92 }
93 
94 std::optional<RegionMatcher::BinaryOpKind>
95 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
96   auto &region = op.getRegion();
97   if (!llvm::hasSingleElement(region))
98     return std::nullopt;
99 
100   Block &block = region.front();
101   if (block.getNumArguments() != 2 ||
102       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
103       !block.getArgument(1).getType().isSignlessIntOrFloat())
104     return std::nullopt;
105 
106   auto &ops = block.getOperations();
107   if (!llvm::hasSingleElement(block.without_terminator()))
108     return std::nullopt;
109 
110   using mlir::matchers::m_Val;
111   auto a = m_Val(block.getArgument(0));
112   auto b = m_Val(block.getArgument(1));
113 
114   auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
115   if (addPattern.match(&ops.back()))
116     return BinaryOpKind::IAdd;
117 
118   return std::nullopt;
119 }
120 
121 /// Explicit instantiation of loop nest generator for different loop types.
122 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
123 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
124 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
125 
126 /// Given a list of subview ranges, extract the individual lower bound, upper
127 /// bound and step values and put them into the corresponding vectors.
128 static void unpackRanges(OpBuilder &builder, Location loc,
129                          ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
130                          SmallVectorImpl<Value> &ubs,
131                          SmallVectorImpl<Value> &steps) {
132   for (Range range : ranges) {
133     lbs.emplace_back(
134         getValueOrCreateConstantIndexOp(builder, loc, range.offset));
135     ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
136     steps.emplace_back(
137         getValueOrCreateConstantIndexOp(builder, loc, range.stride));
138   }
139 }
140 
141 namespace mlir {
142 namespace linalg {
143 
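/// Return true if all the indexing maps of `op` are projected permutations.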
144 bool allIndexingsAreProjectedPermutation(LinalgOp op) {
145   return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
146     return m.isProjectedPermutation(/*allowZeroInResults=*/true);
147   });
148 }
149 
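/// Return true if the single block of `r` contains only scalar operations that
/// are elementwise mappable (plus constants, tensor.extract, linalg.yield,
/// linalg.index and affine.apply) and whose results are all of integer, index
/// or float type.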
150 bool hasOnlyScalarElementwiseOp(Region &r) {
151   if (!llvm::hasSingleElement(r))
152     return false;
153   for (Operation &op : r.front()) {
154     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
155               linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
156           OpTrait::hasElementwiseMappableTraits(&op)) ||
157         llvm::any_of(op.getResultTypes(),
158                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
159       return false;
160   }
161   return true;
162 }
163 
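/// Return true if `op` is elementwise: all loops are parallel, all indexing
/// maps are projected permutations, the indexing maps of the init operands are
/// permutations and the body contains only scalar elementwise operations.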
164 bool isElementwise(LinalgOp op) {
165   if (op.getNumLoops() != op.getNumParallelLoops())
166     return false;
167 
168   if (!allIndexingsAreProjectedPermutation(op))
169     return false;
170 
171   // TODO: relax the restrictions on indexing map.
172   for (OpOperand *opOperand : op.getDpsInitOperands()) {
173     if (!op.getMatchingIndexingMap(opOperand).isPermutation())
174       return false;
175   }
176   return hasOnlyScalarElementwiseOp(op->getRegion(0));
177 }
178 
179 bool isParallelIterator(utils::IteratorType iteratorType) {
180   return iteratorType == utils::IteratorType::parallel;
181 }
182 
183 bool isReductionIterator(utils::IteratorType iteratorType) {
184   return iteratorType == utils::IteratorType::reduction;
185 }
186 
187 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
188 /// the type of `source`.
189 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
190   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
191     return b.createOrFold<memref::DimOp>(loc, source, dim);
192   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
193     return b.createOrFold<tensor::DimOp>(loc, source, dim);
194   llvm_unreachable("Expected MemRefType or TensorType");
195 }
196 
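/// Return dimension `dim` of `source` as an OpFoldResult: an IndexAttr when
/// the size is statically known and a memref/tensor DimOp value otherwise.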
197 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
198                                int64_t dim) {
199   auto shapedType = source.getType().cast<ShapedType>();
200   if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
201     return createOrFoldDimOp(b, loc, source, dim);
202   return b.getIndexAttr(shapedType.getDimSize(dim));
203 }
204 
205 /// Given a shaped value `val`, retrieves the value of each of its dynamic
206 /// dimensions by constructing the necessary DimOps.
207 SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
208   SmallVector<Value, 4> dynOperands;
209   auto shapedType = val.getType().cast<ShapedType>();
210   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
211     if (dim.value() == ShapedType::kDynamic)
212       dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
213   }
214   return dynOperands;
215 }
216 
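/// Compute an upper bound for the index `value` by analyzing the backward
/// slice of affine.apply/affine.min ops that defines it. The bound is returned
/// as `boundMap` applied to `boundOperands`; when no tighter bound can be
/// derived, they default to the identity map over `value` itself. If
/// `constantRequired` is set, only a constant bound replaces that default.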
217 void getUpperBoundForIndex(Value value, AffineMap &boundMap,
218                            SmallVectorImpl<Value> &boundOperands,
219                            bool constantRequired) {
220   // Initialize `boundMap` and `boundOperands` to the identity returning
221   // `value`. This combination is the default result of the method if no
222   // simplification is possible.
223   assert(value.getType().isIndex() && "expect value to have index type");
224   boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
225   boundOperands.assign({value});
226   canonicalizeMapAndOperands(&boundMap, &boundOperands);
227 
228   // Continue only if there is an affine index computation to simplify.
229   Operation *definingOp = value.getDefiningOp();
230   if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
231     return;
232 
233   // Get the backward slice containing the affine index computation.
234   SetVector<Operation *> backwardSlice;
235   getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
236     return isa<AffineApplyOp, AffineMinOp>(op);
237   });
238   backwardSlice.insert(definingOp);
239 
240   // Set up a system of affine constraints describing the index computation.
241   FlatAffineValueConstraints constraints;
242 
243   // Helper to find or create an identifier for the given value.
244   auto findOrCreateId = [&](Value value) {
245     if (!constraints.containsVar(value)) {
246       constraints.appendDimVar(value);
247       return true;
248     }
249     unsigned pos;
250     constraints.findVar(value, &pos);
251     return pos < constraints.getNumDimVars();
252   };
253   // Helper to get the position for the given value.
254   auto getPosition = [&](Value value) {
255     unsigned pos;
256     bool exists = constraints.findVar(value, &pos);
257     (void)exists;
258     assert(exists && "expect to find the identifier");
259     return pos;
260   };
261 
262   // Add the affine operations in `backwardSlice` to the constraints.
263   for (Operation *op : llvm::reverse(backwardSlice)) {
264     // Add an identifier for all op results and operands.
265     if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
266           llvm::all_of(op->getOperands(), findOrCreateId)))
267       return;
268 
269     // Add AffineApplyOps to the constraints.
270     if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
271       AffineMap map = constraints.computeAlignedMap(applyOp.getAffineMap(),
272                                                     applyOp.getOperands());
273       if (failed(constraints.addBound(IntegerPolyhedron::EQ,
274                                       getPosition(applyOp.getResult()), map)))
275         return;
276       continue;
277     }
278     // Add AffineMinOps to the constraints.
279     auto minOp = cast<AffineMinOp>(op);
280     AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
281                                                   minOp.getOperands());
282     if (failed(constraints.addBound(IntegerPolyhedron::UB,
283                                     getPosition(minOp.getResult()), map,
284                                     /*isClosedBound=*/true)))
285       return;
286   }
287 
288   // Obtain an upper bound for the affine index computation by projecting out
289   // all temporary results and expressing the upper bound for `value` in terms
290   // of the terminals of the index computation.
291   unsigned pos = getPosition(value);
292   if (constantRequired) {
293     auto ubConst = constraints.getConstantBound64(
294         FlatAffineValueConstraints::BoundType::UB, pos);
295     if (!ubConst)
296       return;
297 
298     boundMap = AffineMap::getConstantMap(*ubConst, value.getContext());
299     return;
300   }
301 
302   SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
303   constraints.getSliceBounds(pos, 1, value.getContext(), &lowerBounds,
304                              &upperBounds,
305                              /*getClosedUB=*/true);
306   // Verify `upperBounds[0]` is valid and has at least one result.
307   if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
308     return;
309 
310   // Set `boundMap` and `boundOperands` to the computed upper bound.
311   boundMap = upperBounds[0];
312   constraints.getAllValues(&boundOperands);
313   erase_value(boundOperands, value);
314   canonicalizeMapAndOperands(&boundMap, &boundOperands);
315 }
316 
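/// Return the smallest constant upper bound found for the index `value`, or
/// failure if no constant bound can be computed.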
317 FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
318   // Compute an upper bound for `value`.
319   AffineMap boundMap;
320   SmallVector<Value> boundOperands;
321   getUpperBoundForIndex(value, boundMap, boundOperands,
322                         /*constantRequired=*/true);
323 
324   // Search the results of `boundMap` for constant upper bounds.
325   SmallVector<int64_t> constantBounds;
326   for (AffineExpr result : boundMap.getResults())
327     if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
328       constantBounds.push_back(constExpr.getValue());
329 
330   // Return the minimal upper bound or failure if none is found.
331   if (constantBounds.empty())
332     return failure();
333   return *std::min_element(constantBounds.begin(), constantBounds.end());
334 }
335 
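/// Create a tensor::PadOp that high-pads `source` to `type` with the padding
/// value `pad`. If `source` is an extract_slice of a chain of LinalgOps whose
/// init operand is ultimately produced by an equivalent tensor::PadOp (same
/// padding value, no low padding, matching sizes), reuse that already padded
/// value instead of padding again.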
336 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
337                             Value source, Value pad, bool nofold) {
338   // Exit if `source` is not defined by an ExtractSliceOp.
339   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
340   if (!sliceOp)
341     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
342 
343   // Search the `source` use-def chain for padded LinalgOps.
344   Value current = sliceOp.getSource();
345   while (current) {
346     auto linalgOp = current.getDefiningOp<LinalgOp>();
347     if (!linalgOp)
348       break;
349     OpResult opResult = current.cast<OpResult>();
350     current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
351   }
352   auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
353 
354   // Exit if the search fails to match a tensor::PadOp at the end of the matched
355   // LinalgOp sequence.
356   if (!padOp)
357     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
358 
359   // Exit if the padded result type does not match.
360   if (sliceOp.getSource().getType() != type)
361     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
362 
363   // Exit if the LinalgOps are not high padded.
364   if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
365         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
366       }))
367     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
368 
369   // Exit if `padOpSliceOp`, which defines the slice used by
370   // `padOp`, is rank-reducing.
371   auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
372   if (!padOpSliceOp ||
373       sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
374     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
375 
376   // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
377   // by `padOp`.
378   if (llvm::any_of(
379           llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
380           [](std::tuple<OpFoldResult, OpFoldResult> it) {
381             return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
382           }))
383     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
384 
385   // Exit if the padding values do not match.
386   Attribute padOpPadAttr, padAttr;
387   Value padOpPad = padOp.getConstantPaddingValue();
388   if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
389       !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
390     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
391 
392   // Return the padded result if the padding values and sizes match.
393   return sliceOp.getSource();
394 }
395 
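/// Create a GenericOp that transposes `inputTensor` into `outputTensor`
/// according to `transposeVector`, which must be a permutation of the result
/// tensor rank.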
396 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
397                           Value outputTensor,
398                           ArrayRef<int64_t> transposeVector) {
399   auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
400   Type elementType = resultTensorType.getElementType();
401 
402   assert(isPermutationVector(transposeVector) &&
403          "expect transpose vector to be a permutation");
404   assert(transposeVector.size() ==
405              static_cast<size_t>(resultTensorType.getRank()) &&
406          "expect transpose vector size to match result tensor rank");
407 
408   // Compute the transpose and the identity indexing maps.
409   SmallVector<AffineMap> indexingMaps = {
410       inversePermutation(AffineMap::getPermutationMap(
411           SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
412           b.getContext())),
413       AffineMap::getMultiDimIdentityMap(transposeVector.size(),
414                                         b.getContext())};
415   SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
416                                                  utils::IteratorType::parallel);
417 
418   // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
419   auto transposeOp =
420       b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
421                           indexingMaps, iteratorTypes);
422   Region &body = transposeOp.getRegion();
423   body.push_back(new Block());
424   body.front().addArguments({elementType, elementType}, {loc, loc});
425 
426   // Create the body of the transpose operation.
427   OpBuilder::InsertionGuard g(b);
428   b.setInsertionPointToEnd(&body.front());
429   b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
430   return transposeOp;
431 }
432 
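/// Create a GenericOp that copies the memref `from` into the memref `to`
/// elementwise, using identity indexing maps and parallel iterators.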
433 GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
434   auto memrefTypeTo = to.getType().cast<MemRefType>();
435 #ifndef NDEBUG
436   auto memrefTypeFrom = from.getType().cast<MemRefType>();
437   assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
438          "`from` and `to` memref must have the same rank");
439 #endif // NDEBUG
440 
441   AffineMap id =
442       AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
443   SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
444                                                  utils::IteratorType::parallel);
445   return b.create<linalg::GenericOp>(
446       loc,
447       /*inputs=*/from,
448       /*outputs=*/to,
449       /*indexingMaps=*/llvm::ArrayRef({id, id}),
450       /*iteratorTypes=*/iteratorTypes,
451       [](OpBuilder &b, Location loc, ValueRange args) {
452         b.create<linalg::YieldOp>(loc, args.front());
453       });
454 }
455 
456 /// Specialization to build an scf "for" nest.
457 template <>
458 void GenerateLoopNest<scf::ForOp>::doit(
459     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
460     ArrayRef<utils::IteratorType> iteratorTypes,
461     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
462                                   ValueRange)>
463         bodyBuilderFn,
464     ArrayRef<linalg::ProcInfo> procInfo) {
465   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
466          "expected as many entries for proc info as number of loops, even if "
467          "they are null entries");
468   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
469                                              ? SmallVector<Value>{}
470                                              : linalgOp.getDpsInitOperands();
471 
472   SmallVector<Value, 4> lbs, ubs, steps;
473   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
474   LoopNest loopNest = mlir::scf::buildLoopNest(
475       b, loc, lbs, ubs, steps, iterArgInitValues,
476       [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
477         assert(iterArgs.size() == iterArgInitValues.size() &&
478                "expect the number of output tensors and iter args to match");
479         SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
480         if (!iterArgs.empty()) {
481           operandValuesToUse = linalgOp.getDpsInputOperands();
482           operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
483         }
484         return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
485       });
486 
487   if (loopNest.loops.empty() || procInfo.empty())
488     return;
489 
490   // Map loops for cyclically distributed dimensions to their processor ids.
491   for (const auto &loop : llvm::enumerate(loopNest.loops)) {
492     if (procInfo[loop.index()].distributionMethod ==
493         DistributionMethod::Cyclic) {
494       mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
495                             procInfo[loop.index()].nprocs);
496     }
497   }
498 }
499 
500 /// Specialization to build affine "for" nest.
501 template <>
502 void GenerateLoopNest<AffineForOp>::doit(
503     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
504     ArrayRef<utils::IteratorType> iteratorTypes,
505     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
506                                   ValueRange)>
507         bodyBuilderFn,
508     ArrayRef<linalg::ProcInfo> /*procInfo*/) {
509   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
510                                              ? SmallVector<Value>{}
511                                              : linalgOp.getDpsInitOperands();
512   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
513   SmallVector<Value, 4> lbs, ubs, steps;
514   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
515 
516   // Affine loops require constant steps.
517   SmallVector<int64_t, 4> constantSteps;
518   constantSteps.reserve(steps.size());
519   for (Value v : steps) {
520     auto op = v.getDefiningOp<arith::ConstantIndexOp>();
521     assert(op && "Affine loops require constant steps");
522     constantSteps.push_back(op.value());
523   }
524 
525   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
526                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
527                               bodyBuilderFn(b, loc, ivs,
528                                             linalgOp->getOperands());
529                             });
530 }
531 
532 /// Update the `lb`, `ub` and `step` to get per-processor `lb`, `ub` and `step`.
533 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
534                                        Value nprocs, Value &lb, Value &ub,
535                                        Value &step) {
536   AffineExpr d0, d1;
537   bindDims(b.getContext(), d0, d1);
538   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
539   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
540   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
541 }
542 
543 /// Generates a loop nest consisting of scf.parallel and scf.for, depending
544 /// on the `iteratorTypes`. Consecutive parallel loops create a single
545 /// scf.parallel operation; each sequential loop creates a new scf.for
546 /// operation. The body of the innermost loop is populated by
547 /// `bodyBuilderFn` that accepts a range of induction variables for all
548 /// loops. `ivStorage` is used to store the partial list of induction
549 /// variables.
550 // TODO: this function can be made iterative instead. However, it
551 // will have at most as many recursive calls as nested loops, which rarely
552 // exceeds 10.
553 static void generateParallelLoopNest(
554     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
555     ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
556     ArrayRef<linalg::ProcInfo> procInfo,
557     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
558     SmallVectorImpl<Value> &ivStorage) {
559   assert(lbs.size() == ubs.size());
560   assert(lbs.size() == steps.size());
561   assert(lbs.size() == iteratorTypes.size());
562   assert(procInfo.empty() || (lbs.size() == procInfo.size()));
563 
564   // If there are no (more) loops to be generated, generate the body and be
565   // done with it.
566   if (iteratorTypes.empty()) {
567     bodyBuilderFn(b, loc, ivStorage);
568     return;
569   }
570 
571   // If there are no outer parallel loops, generate one sequential loop and
572   // recurse.
573   if (!isParallelIterator(iteratorTypes.front())) {
574     LoopNest singleLoop = buildLoopNest(
575         b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
576         [&](OpBuilder &b, Location loc, ValueRange ivs) {
577           ivStorage.append(ivs.begin(), ivs.end());
578           generateParallelLoopNest(
579               b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
580               iteratorTypes.drop_front(),
581               procInfo.empty() ? procInfo : procInfo.drop_front(),
582               bodyBuilderFn, ivStorage);
583         });
584     return;
585   }
586 
587   unsigned nLoops = iteratorTypes.size();
588   unsigned numProcessed = 0;
589   DistributionMethod distributionMethod = DistributionMethod::None;
590   if (procInfo.empty()) {
591     numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
592   } else {
593     distributionMethod = procInfo.front().distributionMethod;
594     numProcessed =
595         nLoops - procInfo
596                      .drop_while([&](linalg::ProcInfo p) {
597                        return p.distributionMethod == distributionMethod;
598                      })
599                      .size();
600   }
601 
602   auto remainderProcInfo =
603       procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
604   switch (distributionMethod) {
605   case DistributionMethod::None: {
606     // Generate a single parallel loop-nest operation for all outermost
607     // parallel loops and recurse.
608     b.create<scf::ParallelOp>(
609         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
610         steps.take_front(numProcessed),
611         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
612           ivStorage.append(localIvs.begin(), localIvs.end());
613           generateParallelLoopNest(
614               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
615               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
616               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
617               bodyBuilderFn, ivStorage);
618         });
619     return;
620   }
621   case DistributionMethod::Cyclic: {
622     // Generate a single parallel loop-nest operation for all outermost
623     // parallel loops and recurse.
624     b.create<scf::ParallelOp>(
625         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
626         steps.take_front(numProcessed),
627         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
628           ivStorage.append(localIvs.begin(), localIvs.end());
629           generateParallelLoopNest(
630               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
631               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
632               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
633               bodyBuilderFn, ivStorage);
634         });
635     return;
636   }
637   case DistributionMethod::CyclicNumProcsGeNumIters: {
638     // Check (for the processed loops) that the iteration is in-bounds.
639     ArithBuilder ab(b, loc);
640     Value cond = ab.slt(lbs[0], ubs[0]);
641     for (unsigned i = 1; i < numProcessed; ++i)
642       cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
643     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
644     b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
645       generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
646                                ubs.drop_front(numProcessed),
647                                steps.drop_front(numProcessed),
648                                iteratorTypes.drop_front(numProcessed),
649                                remainderProcInfo, bodyBuilderFn, ivStorage);
650       b.create<scf::YieldOp>(loc, ValueRange{});
651     });
652     return;
653   }
654   case DistributionMethod::CyclicNumProcsEqNumIters:
655     // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
656     // with inner loop generation.
657     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
658     generateParallelLoopNest(
659         b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
660         steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
661         remainderProcInfo, bodyBuilderFn, ivStorage);
662     return;
663   }
664 }
665 
666 /// Specialization for generating a mix of parallel and sequential scf loops.
667 template <>
668 void GenerateLoopNest<scf::ParallelOp>::doit(
669     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
670     ArrayRef<utils::IteratorType> iteratorTypes,
671     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
672                                   ValueRange)>
673         bodyBuilderFn,
674     ArrayRef<linalg::ProcInfo> procInfo) {
675   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
676                                              ? SmallVector<Value>{}
677                                              : linalgOp.getDpsInitOperands();
678   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
679   // This function may be passed more iterator types than ranges.
680   assert(iteratorTypes.size() >= loopRanges.size() &&
681          "expected iterator type for all ranges");
682   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
683          "expected proc information for all loops when present");
684   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
685   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
686   unsigned numLoops = iteratorTypes.size();
687   ivs.reserve(numLoops);
688   lbsStorage.reserve(numLoops);
689   ubsStorage.reserve(numLoops);
690   stepsStorage.reserve(numLoops);
691 
692   // Get the loop lb, ub, and step.
693   unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
694 
695   // Modify the lb, ub, and step based on the distribution options.
696   for (const auto &it : llvm::enumerate(procInfo)) {
697     if (it.value().distributionMethod != linalg::DistributionMethod::None) {
698       updateBoundsForCyclicDistribution(
699           b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
700           ubsStorage[it.index()], stepsStorage[it.index()]);
701     }
702   }
703   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
704   generateParallelLoopNest(
705       b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
706       [&](OpBuilder &b, Location loc, ValueRange ivs) {
707         bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
708       },
709       ivs);
710 
711   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
712 }
713 
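/// Materialize the slice described by `sliceParams` as a memref.subview or a
/// tensor.extract_slice of `valueToTile`, depending on its shaped type.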
714 static Value materializeTiledShape(OpBuilder &builder, Location loc,
715                                    Value valueToTile,
716                                    const SliceParameters &sliceParams) {
717   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
718   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
719                       .Case([&](MemRefType) {
720                         return builder.create<memref::SubViewOp>(
721                             loc, valueToTile, sliceParams.offsets,
722                             sliceParams.sizes, sliceParams.strides);
723                       })
724                       .Case([&](RankedTensorType) {
725                         return builder.create<tensor::ExtractSliceOp>(
726                             loc, valueToTile, sliceParams.offsets,
727                             sliceParams.sizes, sliceParams.strides);
728                       })
729                       .Default([](ShapedType) -> Operation * {
730                         llvm_unreachable("Unexpected shaped type");
731                       });
732   return sliceOp->getResult(0);
733 }
734 
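/// Compute the slice parameters for `valueToTile` (see computeSliceParameters
/// below) and immediately materialize the corresponding slice.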
735 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
736                      ArrayRef<OpFoldResult> tileSizes, AffineMap map,
737                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
738                      ArrayRef<OpFoldResult> subShapeSizes,
739                      bool omitPartialTileCheck) {
740   SliceParameters sliceParams =
741       computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
742                              ubs, subShapeSizes, omitPartialTileCheck);
743   return materializeTiledShape(builder, loc, valueToTile, sliceParams);
744 }
745 
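/// Compute the offsets, sizes and strides of the slice of `valueToTile`
/// accessed by one tile: offsets result from applying the indexing `map` to the
/// tile lower bounds `lbs`, sizes from applying it to the closed-interval
/// `subShapeSizes` (converted back to half-open sizes). Dimensions that are not
/// tiled keep their full size. Unless `omitPartialTileCheck` is set, sizes are
/// clamped with an affine.min so that partial/boundary tiles stay in bounds.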
746 SliceParameters
747 computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
748                        ArrayRef<OpFoldResult> tileSizes, AffineMap map,
749                        ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
750                        ArrayRef<OpFoldResult> subShapeSizes,
751                        bool omitPartialTileCheck) {
752   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
753   assert(shapedType && "only shaped types can be tiled");
754   ArrayRef<int64_t> shape = shapedType.getShape();
755   int64_t rank = shapedType.getRank();
756 
757   // Compute offsets/sizes/strides for the tile.
758   SliceParameters sliceParams;
759   sliceParams.offsets.reserve(rank);
760   sliceParams.sizes.reserve(rank);
761   sliceParams.strides.reserve(rank);
762   for (unsigned r = 0; r < rank; ++r) {
763     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
764     if (!isTiled(map.getSubMap({r}), tileSizes)) {
765       sliceParams.offsets.push_back(builder.getIndexAttr(0));
766       OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
767       sliceParams.sizes.push_back(dim);
768       sliceParams.strides.push_back(builder.getIndexAttr(1));
769       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
770       continue;
771     }
772     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
773 
774     // Tiling creates a new slice at the proper index; the slice step is 1
775     // (i.e. the op does not subsample, stepping occurs in the loop).
776     auto m = map.getSubMap({r});
777     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
778     IRRewriter rewriter(builder);
779     OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
780     sliceParams.offsets.push_back(offset);
781     OpFoldResult closedIntSize =
782         makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
783     // The resulting size needs to be made a half-open interval again.
784     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
785     OpFoldResult size =
786         makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
787     LLVM_DEBUG(llvm::dbgs()
788                << "computeSliceParameters: raw size: " << size << "\n");
789     LLVM_DEBUG(llvm::dbgs()
790                << "computeSliceParameters: new offset: " << offset << "\n");
791     sliceParams.strides.push_back(builder.getIndexAttr(1));
792 
793     if (omitPartialTileCheck) {
794       // We statically know that the partial/boundary tile condition is
795       // unnecessary.
796       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
797       sliceParams.sizes.push_back(size);
798       continue;
799     }
800 
801     // The size of the subview / extract_slice should be trimmed to avoid
802     // out-of-bounds accesses, unless:
803     // a. We statically know the subshape size divides the shape size evenly.
804     // b. The subshape size is 1. According to the way the loops are set up,
805     //    tensors with "0" dimensions would never be constructed.
806     int64_t shapeSize = shape[r];
807     std::optional<int64_t> sizeCst = getConstantIntValue(size);
808     auto hasTileSizeOne = sizeCst && *sizeCst == 1;
809     auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
810                          ((shapeSize % *sizeCst) == 0);
811     if (!hasTileSizeOne && !dividesEvenly) {
812       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
813                               << ", size: " << size
814                               << ": make sure in bound with affine.min\n");
815 
816       AffineExpr dim0, dim1, dim2;
817       bindDims(builder.getContext(), dim0, dim1, dim2);
818 
819       // Get the dimension size for this dimension. We need to first calculate
820       // the max index and then add one. This is important because, for
821       // convolution ops, the input window dimension's affine map has the
822       // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
823       // dimension and `s0` is the stride. Directly using the dimension size of
824       // the output/filter window dimensions would yield an incorrect result.
825       AffineMap minusOneMap =
826           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
827               .front();
828       AffineMap plusOneMap =
829           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
830               .front();
831       SmallVector<OpFoldResult> maxIndices =
832           llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
833             return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
834                                                  {ub});
835           }));
836       OpFoldResult maxIndex =
837           makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
838       OpFoldResult d =
839           makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
840 
841       // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
842       AffineMap minMap = AffineMap::inferFromExprList(
843                              {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
844                              .front();
845       size =
846           makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
847     }
848     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
849     sliceParams.sizes.push_back(size);
850   }
851   return sliceParams;
852 }
853 
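/// Compute the per-dimension tile offsets: the current induction variable for
/// tiled dimensions (non-zero tile size) and 0 for untiled ones. For example,
/// with ivs = [%i, %j] and tileSizes = [4, 0, 8], the offsets are [%i, 0, %j].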
854 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
855                                              ArrayRef<OpFoldResult> ivs,
856                                              ArrayRef<OpFoldResult> tileSizes) {
857   SmallVector<OpFoldResult> offsets;
858   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
859     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
860     bool isTiled = !isZeroIndex(tileSizes[idx]);
861     offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
862     LLVM_DEBUG(llvm::dbgs()
863                << "computeTileOffsets: " << offsets.back() << "\n");
864   }
865   return offsets;
866 }
867 
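/// Compute the per-dimension tile sizes as closed intervals, i.e. `size - 1`,
/// using the tile size for tiled dimensions and `sizeBounds` otherwise, so they
/// can be composed with other affine maps. For example, with
/// tileSizes = [4, 0, 8] and sizeBounds = [%d0, %d1, %d2], the result is
/// [3, %d1 - 1, 7].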
868 SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
869                                            ArrayRef<OpFoldResult> tileSizes,
870                                            ArrayRef<OpFoldResult> sizeBounds) {
871   SmallVector<OpFoldResult> sizes;
872   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
873     bool isTiled = !isZeroIndex(tileSizes[idx]);
874     // Before composing, we need to make the range a closed interval.
875     OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
876     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
877     IRRewriter rewriter(b);
878     sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
879     LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
880   }
881   return sizes;
882 }
883 
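/// Return the result types for a tiled version of `op`, i.e. the types of the
/// entries of `operands` that correspond to the DPS init operands. Empty for
/// ops with buffer semantics.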
884 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
885   if (op.hasBufferSemantics())
886     return {};
887   return llvm::to_vector(
888       llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
889         return operands[opOperand->getOperandNumber()].getType();
890       }));
891 }
892 
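/// Create a tensor.insert_slice for each tiled result in `results` to write it
/// back into the corresponding full output tensor. Results whose output operand
/// is not produced by an extract_slice are passed through unchanged.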
893 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
894                                     LinalgOp op, ValueRange operands,
895                                     ValueRange results) {
896   if (op.hasBufferSemantics())
897     return {};
898   SmallVector<Value> tensorResults;
899   tensorResults.reserve(results.size());
900   // Insert an insert_slice for each output tensor.
901   unsigned resultIdx = 0;
902   for (OpOperand *opOperand : op.getDpsInitOperands()) {
903     // TODO: use an interface/adaptor to avoid leaking position in
904     // `tiledOperands`.
905     Value outputTensor = operands[opOperand->getOperandNumber()];
906     if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
907       Value inserted = builder.create<tensor::InsertSliceOp>(
908           loc, sliceOp.getSource().getType(), results[resultIdx],
909           sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
910           sliceOp.getStrides(), sliceOp.getStaticOffsets(),
911           sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
912       tensorResults.push_back(inserted);
913     } else {
914       tensorResults.push_back(results[resultIdx]);
915     }
916     ++resultIdx;
917   }
918   return tensorResults;
919 }
920 
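/// Compute the slice parameters for all `valuesToTile` of `linalgOp`, given
/// the tile offsets implied by `ivs` and the `tileSizes`. Entries are
/// std::nullopt for operands that do not need to be sliced.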
921 SmallVector<std::optional<SliceParameters>>
922 computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
923                           ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
924                           ArrayRef<OpFoldResult> tileSizes,
925                           ArrayRef<OpFoldResult> sizeBounds,
926                           bool omitPartialTileCheck) {
927   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
928                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
929                            [](OpFoldResult v) { return !isZeroIndex(v); })) &&
930          "expected as many ivs as non-zero sizes");
931 
932   // Construct (potentially temporary) mins and maxes on which to apply maps
933   // that define tile subshapes.
934   SmallVector<OpFoldResult> lbs =
935       computeTileOffsets(builder, loc, ivs, tileSizes);
936   SmallVector<OpFoldResult> subShapeSizes =
937       computeTileSizes(builder, loc, tileSizes, sizeBounds);
938 
939   assert(static_cast<int64_t>(valuesToTile.size()) <=
940              linalgOp->getNumOperands() &&
941          "more values to tile than operands.");
942   SmallVector<std::optional<SliceParameters>> allSliceParams;
943   allSliceParams.reserve(valuesToTile.size());
944   for (auto [opOperand, val] :
945        llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
946     Value shapedOp = val;
947     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
948     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
949     // Use `opOperand` as-is if it is not tiled and not an output tensor. Having
950     // an extract/insert slice pair for all output tensors simplifies follow-up
951     // transformations such as padding and bufferization since the
952     // extract/insert slice pairs make the accessed iteration argument
953     // subdomains explicit.
954 
955     Type operandType = opOperand.get().getType();
956     if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
957                                       linalgOp.isDpsInit(&opOperand))) {
958       allSliceParams.push_back(std::nullopt);
959       LLVM_DEBUG(llvm::dbgs()
960                  << ": not tiled: use shape: " << operandType << "\n");
961       continue;
962     }
963     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
964 
965     allSliceParams.push_back(computeSliceParameters(
966         builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
967         omitPartialTileCheck));
968   }
969 
970   return allSliceParams;
971 }
972 
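/// Compute and materialize the tiles of all `valuesToTile`, returning either
/// the slice or the original value when no slicing is needed.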
973 SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
974                                    LinalgOp linalgOp, ValueRange valuesToTile,
975                                    ArrayRef<OpFoldResult> ivs,
976                                    ArrayRef<OpFoldResult> tileSizes,
977                                    ArrayRef<OpFoldResult> sizeBounds,
978                                    bool omitPartialTileCheck) {
979   SmallVector<std::optional<SliceParameters>> allSliceParameter =
980       computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
981                                 tileSizes, sizeBounds, omitPartialTileCheck);
982   SmallVector<Value> tiledShapes;
983   for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
984     Value valueToTile = std::get<0>(item);
985     std::optional<SliceParameters> sliceParams = std::get<1>(item);
986     tiledShapes.push_back(
987         sliceParams.has_value()
988             ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
989             : valueToTile);
990   }
991   return tiledShapes;
992 }
993 
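/// Add the given `offsets` to the values produced by the linalg.index ops of
/// `linalgOp` (OpBuilder wrapper around the RewriterBase overload below).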
994 void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
995                    ArrayRef<OpFoldResult> offsets) {
996   IRRewriter rewriter(b);
997   offsetIndices(rewriter, linalgOp, offsets);
998 }
999 
1000 void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
1001                    ArrayRef<OpFoldResult> offsets) {
1002   if (!linalgOp.hasIndexSemantics())
1003     return;
1004 
1005   for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
1006     if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
1007       continue;
1008     OpBuilder::InsertionGuard guard(b);
1009     b.setInsertionPointAfter(indexOp);
1010     AffineExpr index, offset;
1011     bindDims(b.getContext(), index, offset);
1012     OpFoldResult applied = makeComposedFoldedAffineApply(
1013         b, indexOp.getLoc(), index + offset,
1014         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
1015     Value materialized =
1016         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
1017     b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
1018       return use.getOwner() != materialized.getDefiningOp();
1019     });
1020   }
1021 }
1022 
1023 /// Get the reassociation maps to fold the result of an extract_slice (or the
1024 /// source of an insert_slice) operation with the given offsets and sizes to its
1025 /// rank-reduced version. This is only done for the cases where the size is 1
1026 /// and the offset is 0. Strictly speaking the offset 0 is not required in
1027 /// general, but non-zero offsets are not handled by the SPIR-V backend at this
1028 /// point (and potentially cannot be handled).
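/// For example, mixed sizes [1, 4, 1, 1] produce the single reassociation
/// group [0, 1, 2, 3], while [4, 1, 1, 5] produce the groups [0] and [1, 2, 3].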
1029 std::optional<SmallVector<ReassociationIndices>>
1030 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
1031   SmallVector<ReassociationIndices> reassociation;
1032   ReassociationIndices curr;
1033   for (const auto &it : llvm::enumerate(mixedSizes)) {
1034     auto dim = it.index();
1035     auto size = it.value();
1036     curr.push_back(dim);
1037     auto attr = size.dyn_cast<Attribute>();
1038     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
1039       continue;
1040     reassociation.emplace_back(ReassociationIndices{});
1041     std::swap(reassociation.back(), curr);
1042   }
1043   // If the reassociation is not empty, fold the remaining unit dimensions into
1044   // the last group. If the reassociation is still empty at this point, leave it
1045   // empty; this folds everything to a rank-0 tensor.
1046   if (!curr.empty() && !reassociation.empty())
1047     reassociation.back().append(curr.begin(), curr.end());
1048   return reassociation;
1049 }
1050 
1051 /// Return the identity numeric value associated with the given op.
1052 std::optional<Attribute> getNeutralElement(Operation *op) {
1053   // Builder only used as helper for attribute creation.
1054   OpBuilder b(op->getContext());
1055   Type resultType = op->getResult(0).getType();
1056   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1057     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
1058     if (isa<arith::AddFOp>(op))
1059       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
1060     if (isa<arith::MulFOp>(op))
1061       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
1062     if (isa<arith::MaxFOp>(op))
1063       return b.getFloatAttr(resultType,
1064                             llvm::APFloat::getInf(semantic, /*Negative=*/true));
1065     if (isa<arith::MinFOp>(op))
1066       return b.getFloatAttr(
1067           resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
1068     return std::nullopt;
1069   }
1070   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
1071     return b.getIntegerAttr(resultType, 0);
1072   if (isa<arith::AndIOp>(op))
1073     return b.getIntegerAttr(resultType, -1);
1074   if (isa<arith::MaxSIOp>(op))
1075     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
1076   if (isa<arith::MinSIOp>(op))
1077     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
1078   if (isa<arith::MulIOp>(op))
1079     return b.getIntegerAttr(resultType, 1);
1080   return std::nullopt;
1081 }
1082 
1083 } // namespace linalg
1084 } // namespace mlir
1085