xref: /llvm-project/mlir/lib/Dialect/Linalg/Utils/Utils.cpp (revision f453589039129d90c46930245129d313f0b38714)
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Affine/LoopUtils.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Arith/Utils/Utils.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h"
27 #include "mlir/Dialect/Tensor/Utils/Utils.h"
28 #include "mlir/Dialect/Utils/IndexingUtils.h"
29 #include "mlir/Dialect/Utils/StaticValueUtils.h"
30 #include "mlir/IR/AffineExpr.h"
31 #include "mlir/IR/AffineExprVisitor.h"
32 #include "mlir/IR/AffineMap.h"
33 #include "mlir/IR/Matchers.h"
34 #include "mlir/IR/OpImplementation.h"
35 #include "mlir/Pass/Pass.h"
36 #include "llvm/ADT/TypeSwitch.h"
37 #include "llvm/Support/Debug.h"
38 #include <optional>
39 
40 #define DEBUG_TYPE "linalg-utils"
41 
42 using namespace mlir;
43 using namespace presburger;
44 using namespace mlir::linalg;
45 using namespace mlir::scf;
46 
47 static bool isZero(OpFoldResult v) {
48   if (!v)
49     return false;
50   if (auto attr = v.dyn_cast<Attribute>()) {
51     IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
52     return intAttr && intAttr.getValue().isZero();
53   }
54   if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
55     return cst.value() == 0;
56   return false;
57 }
58 
59 namespace {
60 
61 // Helper visitor to determine whether an AffineExpr is tiled.
62 // This is achieved by traversing every AffineDimExpr with position `pos` and
63 // checking whether the corresponding `tileSizes[pos]` is non-zero.
64 // This also enforces that only positive coefficients occur in multiplications.
65 //
66 // Example:
67 //   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
68 //
69 struct TileCheck : public AffineExprVisitor<TileCheck> {
70   TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
71 
72   void visitDimExpr(AffineDimExpr expr) {
73     isTiled |= !isZero(tileSizes[expr.getPosition()]);
74   }
75   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
76     visit(expr.getLHS());
77     visit(expr.getRHS());
78     if (expr.getKind() == mlir::AffineExprKind::Mul)
79       assert(expr.getRHS().cast<AffineConstantExpr>().getValue() > 0 &&
80              "nonpositive multiplying coefficient");
81   }
82   bool isTiled = false;
83   ArrayRef<OpFoldResult> tileSizes;
84 };
85 
86 } // namespace
87 
88 static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
89   if (!expr)
90     return false;
91   TileCheck t(tileSizes);
92   t.visit(expr);
93   return t.isTiled;
94 }
95 
96 // Checks whether the `map` varies with respect to a non-zero `tileSize`.
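// For example, the map `(d0, d1) -> (d0 * 8 + d1)` is tiled by the tile sizes
// [0, 2] (d1 has a non-zero tile size) but not by [0, 0].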
97 static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
98   if (!map)
99     return false;
100   for (unsigned r = 0; r < map.getNumResults(); ++r)
101     if (isTiled(map.getResult(r), tileSizes))
102       return true;
103   return false;
104 }
105 
106 std::optional<RegionMatcher::BinaryOpKind>
107 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
108   auto &region = op.getRegion();
109   if (!llvm::hasSingleElement(region))
110     return std::nullopt;
111 
112   Block &block = region.front();
113   if (block.getNumArguments() != 2 ||
114       !block.getArgument(0).getType().isSignlessIntOrFloat() ||
115       !block.getArgument(1).getType().isSignlessIntOrFloat())
116     return std::nullopt;
117 
118   auto &ops = block.getOperations();
119   if (!llvm::hasSingleElement(block.without_terminator()))
120     return std::nullopt;
121 
122   using mlir::matchers::m_Val;
123   auto a = m_Val(block.getArgument(0));
124   auto b = m_Val(block.getArgument(1));
125 
126   auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
127   if (addPattern.match(&ops.back()))
128     return BinaryOpKind::IAdd;
129 
130   return std::nullopt;
131 }
132 
133 /// Explicit instantiation of loop nest generator for different loop types.
134 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
135 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
136 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
137 
138 /// Given a list of subview ranges, extract the individual values for the lower
139 /// bounds, upper bounds, and steps and put them into the corresponding vectors.
140 static void unpackRanges(OpBuilder &builder, Location loc,
141                          ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
142                          SmallVectorImpl<Value> &ubs,
143                          SmallVectorImpl<Value> &steps) {
144   for (Range range : ranges) {
145     lbs.emplace_back(
146         getValueOrCreateConstantIndexOp(builder, loc, range.offset));
147     ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
148     steps.emplace_back(
149         getValueOrCreateConstantIndexOp(builder, loc, range.stride));
150   }
151 }
152 
153 namespace mlir {
154 namespace linalg {
155 
156 bool allIndexingsAreProjectedPermutation(LinalgOp op) {
157   return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
158     return m.isProjectedPermutation(/*allowZeroInResults=*/true);
159   });
160 }
161 
162 bool hasOnlyScalarElementwiseOp(Region &r) {
163   if (!llvm::hasSingleElement(r))
164     return false;
165   for (Operation &op : r.front()) {
166     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
167               linalg::YieldOp, linalg::IndexOp>(op) ||
168           OpTrait::hasElementwiseMappableTraits(&op)) ||
169         llvm::any_of(op.getResultTypes(),
170                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
171       return false;
172   }
173   return true;
174 }
175 
176 bool isElementwise(LinalgOp op) {
177   if (op.getNumLoops() != op.getNumParallelLoops())
178     return false;
179 
180   if (!allIndexingsAreProjectedPermutation(op))
181     return false;
182 
183   // TODO: relax the restrictions on indexing map.
184   for (OpOperand *opOperand : op.getDpsInitOperands()) {
185     if (!op.getMatchingIndexingMap(opOperand).isPermutation())
186       return false;
187   }
188   return hasOnlyScalarElementwiseOp(op->getRegion(0));
189 }
190 
191 bool isParallelIterator(utils::IteratorType iteratorType) {
192   return iteratorType == utils::IteratorType::parallel;
193 }
194 
195 bool isReductionIterator(utils::IteratorType iteratorType) {
196   return iteratorType == utils::IteratorType::reduction;
197 }
198 
199 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
200 /// the type of `source`.
201 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
202   if (source.getType().isa<UnrankedMemRefType, MemRefType>())
203     return b.createOrFold<memref::DimOp>(loc, source, dim);
204   if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
205     return b.createOrFold<tensor::DimOp>(loc, source, dim);
206   llvm_unreachable("Expected MemRefType or TensorType");
207 }
208 
209 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
210                                int64_t dim) {
211   auto shapedType = source.getType().cast<ShapedType>();
212   if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
213     return createOrFoldDimOp(b, loc, source, dim);
214   return b.getIndexAttr(shapedType.getDimSize(dim));
215 }
216 
217 /// Given a shaped value, retrieves the size of each of its dynamic dimensions
218 /// by constructing the necessary DimOp operations.
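/// For example, for a `val` of type `tensor<4x?x?xf32>` this returns the two
/// `tensor.dim` values for dimensions 1 and 2, in that order.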
219 SmallVector<Value, 4> getDynOperands(Location loc, Value val, OpBuilder &b) {
220   SmallVector<Value, 4> dynOperands;
221   auto shapedType = val.getType().cast<ShapedType>();
222   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
223     if (dim.value() == ShapedType::kDynamic)
224       dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
225   }
226   return dynOperands;
227 }
228 
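/// Compute an upper bound for `value`, which is expected to be the result of an
/// affine index computation, by building a constraint system over the backward
/// slice of affine.apply/affine.min ops that define it. On success, `boundMap`
/// applied to `boundOperands` gives the upper bound; if no simplification is
/// possible, the identity map over `value` is kept. If `constantRequired` is
/// set, only constant upper bounds are considered.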
229 void getUpperBoundForIndex(Value value, AffineMap &boundMap,
230                            SmallVectorImpl<Value> &boundOperands,
231                            bool constantRequired) {
232   // Initialize `boundMap` and `boundOperands` to the identity returning
233   // `value`. This combination is the default result of the method if no
234   // simplification is possible.
235   assert(value.getType().isIndex() && "expect value to have index type");
236   boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
237   boundOperands.assign({value});
238   canonicalizeMapAndOperands(&boundMap, &boundOperands);
239 
240   // Continue only if there is an affine index computation to simplify.
241   Operation *definingOp = value.getDefiningOp();
242   if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
243     return;
244 
245   // Get the backward slice containing the affine index computation.
246   SetVector<Operation *> backwardSlice;
247   getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
248     return isa<AffineApplyOp, AffineMinOp>(op);
249   });
250   backwardSlice.insert(definingOp);
251 
252   // Setup a system of affine constraints that describe the index computation.
253   FlatAffineValueConstraints constraints;
254 
255   // Helper to find or create an identifier for the given value.
256   auto findOrCreateId = [&](Value value) {
257     if (!constraints.containsVar(value)) {
258       constraints.appendDimVar(value);
259       return true;
260     }
261     unsigned pos;
262     constraints.findVar(value, &pos);
263     return pos < constraints.getNumDimVars();
264   };
265   // Helper to get the position for the given value.
266   auto getPosition = [&](Value value) {
267     unsigned pos;
268     bool exists = constraints.findVar(value, &pos);
269     (void)exists;
270     assert(exists && "expect to find the identifier");
271     return pos;
272   };
273 
274   // Add the affine operations in `backwardSlice` to the constraints.
275   for (Operation *op : llvm::reverse(backwardSlice)) {
276     // Add an identifier for all op results and operands.
277     if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
278           llvm::all_of(op->getOperands(), findOrCreateId)))
279       return;
280 
281     // Add AffineApplyOps to the constraints.
282     if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
283       AffineMap map = constraints.computeAlignedMap(applyOp.getAffineMap(),
284                                                     applyOp.getOperands());
285       if (failed(constraints.addBound(IntegerPolyhedron::EQ,
286                                       getPosition(applyOp.getResult()), map)))
287         return;
288       continue;
289     }
290     // Add AffineMinOps to the constraints.
291     auto minOp = cast<AffineMinOp>(op);
292     AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
293                                                   minOp.getOperands());
294     if (failed(constraints.addBound(IntegerPolyhedron::UB,
295                                     getPosition(minOp.getResult()), map,
296                                     /*isClosedBound=*/true)))
297       return;
298   }
299 
300   // Obtain an upper bound for the affine index computation by projecting out
301   // all temporary results and expressing the upper bound for `value` in terms
302   // of the terminals of the index computation.
303   unsigned pos = getPosition(value);
304   if (constantRequired) {
305     auto ubConst = constraints.getConstantBound64(
306         FlatAffineValueConstraints::BoundType::UB, pos);
307     if (!ubConst)
308       return;
309 
310     boundMap = AffineMap::getConstantMap(*ubConst, value.getContext());
311     return;
312   }
313 
314   SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
315   constraints.getSliceBounds(pos, 1, value.getContext(), &lowerBounds,
316                              &upperBounds,
317                              /*getClosedUB=*/true);
318   // Verify `upperBounds[0]` is valid and has at least one result.
319   if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
320     return;
321 
322   // Set `boundMap` and `boundOperands` to the computed upper bound.
323   boundMap = upperBounds[0];
324   constraints.getAllValues(&boundOperands);
325   erase_value(boundOperands, value);
326   canonicalizeMapAndOperands(&boundMap, &boundOperands);
327 }
328 
329 FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
330   // Compute an upper bound for `value`.
331   AffineMap boundMap;
332   SmallVector<Value> boundOperands;
333   getUpperBoundForIndex(value, boundMap, boundOperands,
334                         /*constantRequired=*/true);
335 
336   // Search the results of `boundMap` for constant upper bounds.
337   SmallVector<int64_t> constantBounds;
338   for (AffineExpr result : boundMap.getResults())
339     if (auto constExpr = result.dyn_cast<AffineConstantExpr>())
340       constantBounds.push_back(constExpr.getValue());
341 
342   // Return the minimal upper bound or failure if none is found.
343   if (constantBounds.empty())
344     return failure();
345   return *std::min_element(constantBounds.begin(), constantBounds.end());
346 }
347 
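/// Create a tensor::PadOp that pads `source` up to the shape of `type` using
/// `pad` as the padding value. If `source` is an extract_slice of a chain of
/// high-padded LinalgOps whose padding value and slice sizes match, reuse the
/// previously padded value instead of creating a new pad.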
348 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
349                             Value source, Value pad, bool nofold) {
350   // Exit if `source` is not defined by an ExtractSliceOp.
351   auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
352   if (!sliceOp)
353     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
354 
355   // Search the `source` use-def chain for padded LinalgOps.
356   Value current = sliceOp.getSource();
357   while (current) {
358     auto linalgOp = current.getDefiningOp<LinalgOp>();
359     if (!linalgOp)
360       break;
361     OpResult opResult = current.cast<OpResult>();
362     current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
363   }
364   auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
365 
366   // Exit if the search fails to match a tensor::PadOp at the end of the matched
367   // LinalgOp sequence.
368   if (!padOp)
369     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
370 
371   // Exit if the padded result type does not match.
372   if (sliceOp.getSource().getType() != type)
373     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
374 
375   // Exit if the LinalgOps are not high padded.
376   if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
377         return getConstantIntValue(ofr) != static_cast<int64_t>(0);
378       }))
379     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
380 
381   // Exit if `padOpSliceOp`, which defines the slice used by
382   // `padOp`, is rank-reducing.
383   auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
384   if (!padOpSliceOp ||
385       sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
386     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
387 
388   // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
389   // by `padOp`.
390   if (llvm::any_of(
391           llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
392           [](std::tuple<OpFoldResult, OpFoldResult> it) {
393             return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
394           }))
395     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
396 
397   // Exit if the padding values do not match.
398   Attribute padOpPadAttr, padAttr;
399   Value padOpPad = padOp.getConstantPaddingValue();
400   if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
401       !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
402     return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
403 
404   // Return the padded result if the padding values and sizes match.
405   return sliceOp.getSource();
406 }
407 
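/// Create a GenericOp that transposes `inputTensor` into `outputTensor`
/// according to the permutation given by `transposeVector`.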
408 GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
409                           Value outputTensor,
410                           ArrayRef<int64_t> transposeVector) {
411   auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
412   Type elementType = resultTensorType.getElementType();
413 
414   assert(isPermutationVector(transposeVector) &&
415          "expect transpose vector to be a permutation");
416   assert(transposeVector.size() ==
417              static_cast<size_t>(resultTensorType.getRank()) &&
418          "expect transpose vector size to match result tensor rank");
419 
420   // Compute the transpose and the identity indexing maps.
421   SmallVector<AffineMap> indexingMaps = {
422       inversePermutation(AffineMap::getPermutationMap(
423           SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
424           b.getContext())),
425       AffineMap::getMultiDimIdentityMap(transposeVector.size(),
426                                         b.getContext())};
427   SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
428                                                  utils::IteratorType::parallel);
429 
430   // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
431   auto transposeOp =
432       b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
433                           indexingMaps, iteratorTypes);
434   Region &body = transposeOp.getRegion();
435   body.push_back(new Block());
436   body.front().addArguments({elementType, elementType}, {loc, loc});
437 
438   // Create the body of the transpose operation.
439   OpBuilder::InsertionGuard g(b);
440   b.setInsertionPointToEnd(&body.front());
441   b.create<YieldOp>(loc, transposeOp.getRegion().front().getArgument(0));
442   return transposeOp;
443 }
444 
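/// Create a GenericOp that copies the memref `from` into the memref `to`; both
/// memrefs must have the same rank.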
445 GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
446   auto memrefTypeTo = to.getType().cast<MemRefType>();
447 #ifndef NDEBUG
448   auto memrefTypeFrom = from.getType().cast<MemRefType>();
449   assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
450          "`from` and `to` memrefs must have the same rank");
451 #endif // NDEBUG
452 
453   AffineMap id =
454       AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
455   SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
456                                                  utils::IteratorType::parallel);
457   return b.create<linalg::GenericOp>(
458       loc,
459       /*inputs=*/from,
460       /*outputs=*/to,
461       /*indexingMaps=*/llvm::ArrayRef({id, id}),
462       /*iteratorTypes=*/iteratorTypes,
463       [](OpBuilder &b, Location loc, ValueRange args) {
464         b.create<linalg::YieldOp>(loc, args.front());
465       });
466 }
467 
468 /// Specialization to build an scf "for" nest.
469 template <>
470 void GenerateLoopNest<scf::ForOp>::doit(
471     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
472     ArrayRef<utils::IteratorType> iteratorTypes,
473     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
474                                   ValueRange)>
475         bodyBuilderFn,
476     ArrayRef<linalg::ProcInfo> procInfo) {
477   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
478          "expected as many entries for proc info as number of loops, even if "
479          "they are null entries");
480   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
481                                              ? SmallVector<Value>{}
482                                              : linalgOp.getDpsInitOperands();
483 
484   SmallVector<Value, 4> lbs, ubs, steps;
485   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
486   LoopNest loopNest = mlir::scf::buildLoopNest(
487       b, loc, lbs, ubs, steps, iterArgInitValues,
488       [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
489         assert(iterArgs.size() == iterArgInitValues.size() &&
490                "expect the number of output tensors and iter args to match");
491         SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
492         if (!iterArgs.empty()) {
493           operandValuesToUse = linalgOp.getDpsInputOperands();
494           operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
495         }
496         return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
497       });
498 
499   if (loopNest.loops.empty() || procInfo.empty())
500     return;
501 
502   // Map the loops created from parallel dimensions with cyclic distribution.
503   for (const auto &loop : llvm::enumerate(loopNest.loops)) {
504     if (procInfo[loop.index()].distributionMethod ==
505         DistributionMethod::Cyclic) {
506       mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
507                             procInfo[loop.index()].nprocs);
508     }
509   }
510 }
511 
512 /// Specialization to build affine "for" nest.
513 template <>
514 void GenerateLoopNest<AffineForOp>::doit(
515     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
516     ArrayRef<utils::IteratorType> iteratorTypes,
517     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
518                                   ValueRange)>
519         bodyBuilderFn,
520     ArrayRef<linalg::ProcInfo> /*procInfo*/) {
521   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
522                                              ? SmallVector<Value>{}
523                                              : linalgOp.getDpsInitOperands();
524   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
525   SmallVector<Value, 4> lbs, ubs, steps;
526   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
527 
528   // Affine loops require constant steps.
529   SmallVector<int64_t, 4> constantSteps;
530   constantSteps.reserve(steps.size());
531   for (Value v : steps) {
532     auto op = v.getDefiningOp<arith::ConstantIndexOp>();
533     assert(op && "Affine loops require constant steps");
534     constantSteps.push_back(op.value());
535   }
536 
537   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
538                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
539                               bodyBuilderFn(b, loc, ivs,
540                                             linalgOp->getOperands());
541                             });
542 }
543 
544 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
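/// A sketch of the cyclic distribution that is applied:
///   lb   <- lb + procId * step
///   step <- nprocs * step
/// so that the processor with id `procId` starts at its own offset and visits
/// every `nprocs`-th iteration; `ub` is left unchanged.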
545 void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
546                                        Value nprocs, Value &lb, Value &ub,
547                                        Value &step) {
548   AffineExpr d0, d1;
549   bindDims(b.getContext(), d0, d1);
550   AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
551   lb = makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
552   step = makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
553 }
554 
555 /// Generates a loop nest consisting of scf.parallel and scf.for, depending
556 /// on the `iteratorTypes`. Consecutive parallel loops create a single
557 /// scf.parallel operation; each sequential loop creates a new scf.for
558 /// operation. The body of the innermost loop is populated by
559 /// `bodyBuilderFn` that accepts a range of induction variables for all
560 /// loops. `ivStorage` is used to store the partial list of induction
561 /// variables.
562 // TODO: this function can be made iterative instead. However, it
563 // will have at most as many recursive calls as nested loops, which rarely
564 // exceeds 10.
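//
// For example, with iteratorTypes = [parallel, parallel, reduction, parallel]
// and no distribution, the generated nest looks like (sketch):
//   scf.parallel (%iv0, %iv1) = ... {
//     scf.for %iv2 = ... {
//       scf.parallel (%iv3) = ... {
//         bodyBuilderFn(%iv0, %iv1, %iv2, %iv3)
//       }
//     }
//   }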
565 static void generateParallelLoopNest(
566     OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
567     ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
568     ArrayRef<linalg::ProcInfo> procInfo,
569     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
570     SmallVectorImpl<Value> &ivStorage) {
571   assert(lbs.size() == ubs.size());
572   assert(lbs.size() == steps.size());
573   assert(lbs.size() == iteratorTypes.size());
574   assert(procInfo.empty() || (lbs.size() == procInfo.size()));
575 
576   // If there are no (more) loops to be generated, generate the body and be
577   // done with it.
578   if (iteratorTypes.empty()) {
579     bodyBuilderFn(b, loc, ivStorage);
580     return;
581   }
582 
583   // If there are no outer parallel loops, generate one sequential loop and
584   // recurse.
585   if (!isParallelIterator(iteratorTypes.front())) {
586     LoopNest singleLoop = buildLoopNest(
587         b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
588         [&](OpBuilder &b, Location loc, ValueRange ivs) {
589           ivStorage.append(ivs.begin(), ivs.end());
590           generateParallelLoopNest(
591               b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
592               iteratorTypes.drop_front(),
593               procInfo.empty() ? procInfo : procInfo.drop_front(),
594               bodyBuilderFn, ivStorage);
595         });
596     return;
597   }
598 
599   unsigned nLoops = iteratorTypes.size();
600   unsigned numProcessed = 0;
601   DistributionMethod distributionMethod = DistributionMethod::None;
602   if (procInfo.empty()) {
603     numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
604   } else {
605     distributionMethod = procInfo.front().distributionMethod;
606     numProcessed =
607         nLoops - procInfo
608                      .drop_while([&](linalg::ProcInfo p) {
609                        return p.distributionMethod == distributionMethod;
610                      })
611                      .size();
612   }
613 
614   auto remainderProcInfo =
615       procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
616   switch (distributionMethod) {
617   case DistributionMethod::None: {
618     // Generate a single parallel loop-nest operation for all outermost
619     // parallel loops and recurse.
620     b.create<scf::ParallelOp>(
621         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
622         steps.take_front(numProcessed),
623         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
624           ivStorage.append(localIvs.begin(), localIvs.end());
625           generateParallelLoopNest(
626               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
627               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
628               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
629               bodyBuilderFn, ivStorage);
630         });
631     return;
632   }
633   case DistributionMethod::Cyclic: {
634     // Generate a single parallel loop-nest operation for all outermost
635     // parallel loops and recurse.
636     b.create<scf::ParallelOp>(
637         loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
638         steps.take_front(numProcessed),
639         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
640           ivStorage.append(localIvs.begin(), localIvs.end());
641           generateParallelLoopNest(
642               nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
643               ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
644               iteratorTypes.drop_front(numProcessed), remainderProcInfo,
645               bodyBuilderFn, ivStorage);
646         });
647     return;
648   }
649   case DistributionMethod::CyclicNumProcsGeNumIters: {
650     // Check (for the processed loops) that the iteration is in-bounds.
651     ArithBuilder ab(b, loc);
652     Value cond = ab.slt(lbs[0], ubs[0]);
653     for (unsigned i = 1; i < numProcessed; ++i)
654       cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
655     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
656     b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
657       generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
658                                ubs.drop_front(numProcessed),
659                                steps.drop_front(numProcessed),
660                                iteratorTypes.drop_front(numProcessed),
661                                remainderProcInfo, bodyBuilderFn, ivStorage);
662       b.create<scf::YieldOp>(loc, ValueRange{});
663     });
664     return;
665   }
666   case DistributionMethod::CyclicNumProcsEqNumIters:
667     // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
668     // with inner loop generation.
669     ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
670     generateParallelLoopNest(
671         b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
672         steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
673         remainderProcInfo, bodyBuilderFn, ivStorage);
674     return;
675   }
676 }
677 
678 /// Specialization for generating a mix of parallel and sequential scf loops.
679 template <>
680 void GenerateLoopNest<scf::ParallelOp>::doit(
681     OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
682     ArrayRef<utils::IteratorType> iteratorTypes,
683     function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
684                                   ValueRange)>
685         bodyBuilderFn,
686     ArrayRef<linalg::ProcInfo> procInfo) {
687   SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
688                                              ? SmallVector<Value>{}
689                                              : linalgOp.getDpsInitOperands();
690   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
691   // This function may be passed more iterator types than ranges.
692   assert(iteratorTypes.size() >= loopRanges.size() &&
693          "expected iterator type for all ranges");
694   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
695          "expected proc information for all loops when present");
696   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
697   SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
698   unsigned numLoops = iteratorTypes.size();
699   ivs.reserve(numLoops);
700   lbsStorage.reserve(numLoops);
701   ubsStorage.reserve(numLoops);
702   stepsStorage.reserve(numLoops);
703 
704   // Get the loop lb, ub, and step.
705   unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
706 
707   // Modify the lb, ub, and step based on the distribution options.
708   for (const auto &it : llvm::enumerate(procInfo)) {
709     if (it.value().distributionMethod != linalg::DistributionMethod::None) {
710       updateBoundsForCyclicDistribution(
711           b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
712           ubsStorage[it.index()], stepsStorage[it.index()]);
713     }
714   }
715   ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
716   generateParallelLoopNest(
717       b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
718       [&](OpBuilder &b, Location loc, ValueRange ivs) {
719         bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
720       },
721       ivs);
722 
723   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
724 }
725 
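/// Materialize the tile described by `sliceParams` as a memref.subview or a
/// tensor.extract_slice of `valueToTile`, depending on its type.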
726 static Value materializeTiledShape(OpBuilder &builder, Location loc,
727                                    Value valueToTile,
728                                    const SliceParameters &sliceParams) {
729   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
730   auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
731                       .Case([&](MemRefType) {
732                         return builder.create<memref::SubViewOp>(
733                             loc, valueToTile, sliceParams.offsets,
734                             sliceParams.sizes, sliceParams.strides);
735                       })
736                       .Case([&](RankedTensorType) {
737                         return builder.create<tensor::ExtractSliceOp>(
738                             loc, valueToTile, sliceParams.offsets,
739                             sliceParams.sizes, sliceParams.strides);
740                       })
741                       .Default([](ShapedType) -> Operation * {
742                         llvm_unreachable("Unexpected shaped type");
743                       });
744   return sliceOp->getResult(0);
745 }
746 
747 Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
748                      ArrayRef<OpFoldResult> tileSizes, AffineMap map,
749                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
750                      ArrayRef<OpFoldResult> subShapeSizes,
751                      bool omitPartialTileCheck) {
752   SliceParameters sliceParams =
753       computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
754                              ubs, subShapeSizes, omitPartialTileCheck);
755   return materializeTiledShape(builder, loc, valueToTile, sliceParams);
756 }
757 
758 SliceParameters
759 computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
760                        ArrayRef<OpFoldResult> tileSizes, AffineMap map,
761                        ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
762                        ArrayRef<OpFoldResult> subShapeSizes,
763                        bool omitPartialTileCheck) {
764   auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
765   assert(shapedType && "only shaped types can be tiled");
766   ArrayRef<int64_t> shape = shapedType.getShape();
767   int64_t rank = shapedType.getRank();
768 
769   // Compute offsets/sizes/strides for the tile.
770   SliceParameters sliceParams;
771   sliceParams.offsets.reserve(rank);
772   sliceParams.sizes.reserve(rank);
773   sliceParams.strides.reserve(rank);
774   for (unsigned r = 0; r < rank; ++r) {
775     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
776     if (!isTiled(map.getSubMap({r}), tileSizes)) {
777       sliceParams.offsets.push_back(builder.getIndexAttr(0));
778       OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
779       sliceParams.sizes.push_back(dim);
780       sliceParams.strides.push_back(builder.getIndexAttr(1));
781       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
782       continue;
783     }
784     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
785 
786     // Tiling creates a new slice at the proper index; the slice step is 1
787     // (i.e. the op does not subsample; stepping occurs in the loop).
788     auto m = map.getSubMap({r});
789     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
790     IRRewriter rewriter(builder);
791     OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
792     sliceParams.offsets.push_back(offset);
793     OpFoldResult closedIntSize =
794         makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
795     // Resulting size needs to be made half open interval again.
796     AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
797     OpFoldResult size =
798         makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
799     LLVM_DEBUG(llvm::dbgs()
800                << "computeSliceParameters: raw size: " << size << "\n");
801     LLVM_DEBUG(llvm::dbgs()
802                << "computeSliceParameters: new offset: " << offset << "\n");
803     sliceParams.strides.push_back(builder.getIndexAttr(1));
804 
805     if (omitPartialTileCheck) {
806       // We statically know that the partial/boundary tile condition is
807       // unnecessary.
808       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
809       sliceParams.sizes.push_back(size);
810       continue;
811     }
812 
813     // The size of the subview / extract_slice should be trimmed to avoid
814     // out-of-bounds accesses, unless:
815     // a. We statically know the subshape size divides the shape size evenly.
816     // b. The subshape size is 1. According to the way the loops are set up,
817     //    tensors with "0" dimensions would never be constructed.
818     int64_t shapeSize = shape[r];
819     std::optional<int64_t> sizeCst = getConstantIntValue(size);
820     auto hasTileSizeOne = sizeCst && *sizeCst == 1;
821     auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
822                          ((shapeSize % *sizeCst) == 0);
823     if (!hasTileSizeOne && !dividesEvenly) {
824       LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
825                               << ", size: " << size
826                               << ": make sure in bound with affine.min\n");
827 
828       AffineExpr dim0, dim1, dim2;
829       bindDims(builder.getContext(), dim0, dim1, dim2);
830 
831       // Get the dimension size for this dimension. We need to first calculate
832       // the max index and then add one. This is important because for
833       // convolution ops, the input window dimension's affine map has the form
834       // `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window dimension
835       // and `s0` is the stride. Directly using the dimension size of the
836       // output/filter window dimensions would cause incorrect calculations.
837       AffineMap minusOneMap =
838           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 - 1}})
839               .front();
840       AffineMap plusOneMap =
841           AffineMap::inferFromExprList({ArrayRef<AffineExpr>{dim0 + 1}})
842               .front();
843       SmallVector<OpFoldResult> maxIndices =
844           llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
845             return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
846                                                  {ub});
847           }));
848       OpFoldResult maxIndex =
849           makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
850       OpFoldResult d =
851           makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
852 
853       // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
854       AffineMap minMap = AffineMap::inferFromExprList(
855                              {ArrayRef<AffineExpr>{dim1 - dim2, dim0}})
856                              .front();
857       size =
858           makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
859     }
860     LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
861     sliceParams.sizes.push_back(size);
862   }
863   return sliceParams;
864 }
865 
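/// Compute the tile offsets: for every tiled dimension (non-zero tile size),
/// the next induction variable from `ivs` is used; for every untiled dimension,
/// the offset is 0. For example, with tileSizes = [0, 8, 0, 16] and
/// ivs = (%i, %j), this returns [0, %i, 0, %j].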
866 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
867                                              ArrayRef<OpFoldResult> ivs,
868                                              ArrayRef<OpFoldResult> tileSizes) {
869   SmallVector<OpFoldResult> offsets;
870   for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
871     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
872     bool isTiled = !isZero(tileSizes[idx]);
873     offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
874     LLVM_DEBUG(llvm::dbgs()
875                << "computeTileOffsets: " << offsets.back() << "\n");
876   }
877   return offsets;
878 }
879 
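/// Compute closed-interval tile sizes: `tileSizes[idx] - 1` for every tiled
/// dimension and `sizeBounds[idx] - 1` otherwise. These are the `subShapeSizes`
/// consumed by `computeSliceParameters`, which turns them back into half-open
/// sizes by adding 1 after composing the indexing map.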
880 SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
881                                            ArrayRef<OpFoldResult> tileSizes,
882                                            ArrayRef<OpFoldResult> sizeBounds) {
883   SmallVector<OpFoldResult> sizes;
884   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
885     bool isTiled = !isZero(tileSizes[idx]);
886     // Before composing, we need to make the range a closed interval.
887     OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
888     AffineExpr d0 = getAffineDimExpr(0, b.getContext());
889     IRRewriter rewriter(b);
890     sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
891     LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
892   }
893   return sizes;
894 }
895 
896 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
897   if (op.hasBufferSemantics())
898     return {};
899   return llvm::to_vector(
900       llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) {
901         return operands[opOperand->getOperandNumber()].getType();
902       }));
903 }
904 
905 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
906                                     LinalgOp op, ValueRange operands,
907                                     ValueRange results) {
908   if (op.hasBufferSemantics())
909     return {};
910   SmallVector<Value> tensorResults;
911   tensorResults.reserve(results.size());
912   // Insert an insert_slice for each output tensor.
913   unsigned resultIdx = 0;
914   for (OpOperand *opOperand : op.getDpsInitOperands()) {
915     // TODO: use an interface/adaptor to avoid leaking position in
916     // `tiledOperands`.
917     Value outputTensor = operands[opOperand->getOperandNumber()];
918     if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
919       Value inserted = builder.create<tensor::InsertSliceOp>(
920           loc, sliceOp.getSource().getType(), results[resultIdx],
921           sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
922           sliceOp.getStrides(), sliceOp.getStaticOffsets(),
923           sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
924       tensorResults.push_back(inserted);
925     } else {
926       tensorResults.push_back(results[resultIdx]);
927     }
928     ++resultIdx;
929   }
930   return tensorResults;
931 }
932 
933 SmallVector<std::optional<SliceParameters>>
934 computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
935                           ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
936                           ArrayRef<OpFoldResult> tileSizes,
937                           ArrayRef<OpFoldResult> sizeBounds,
938                           bool omitPartialTileCheck) {
939   assert(ivs.size() == static_cast<size_t>(llvm::count_if(
940                            llvm::make_range(tileSizes.begin(), tileSizes.end()),
941                            [](OpFoldResult v) { return !isZero(v); })) &&
942          "expected as many ivs as non-zero sizes");
943 
944   // Construct (potentially temporary) mins and maxes on which to apply maps
945   // that define tile subshapes.
946   SmallVector<OpFoldResult> lbs =
947       computeTileOffsets(builder, loc, ivs, tileSizes);
948   SmallVector<OpFoldResult> subShapeSizes =
949       computeTileSizes(builder, loc, tileSizes, sizeBounds);
950 
951   assert(static_cast<int64_t>(valuesToTile.size()) <=
952              linalgOp->getNumOperands() &&
953          "more values to tile than operands.");
954   SmallVector<std::optional<SliceParameters>> allSliceParams;
955   allSliceParams.reserve(valuesToTile.size());
956   for (auto [opOperand, val] :
957        llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
958     Value shapedOp = val;
959     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
960     AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
961     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
962     // an extract/insert slice pair for all output tensors simplifies follow-up
963     // transformations such as padding and bufferization since the
964     // extract/insert slice pairs make the accessed iteration argument
965     // subdomains explicit.
966 
967     Type operandType = opOperand.get().getType();
968     if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
969                                       linalgOp.isDpsInit(&opOperand))) {
970       allSliceParams.push_back(std::nullopt);
971       LLVM_DEBUG(llvm::dbgs()
972                  << ": not tiled: use shape: " << operandType << "\n");
973       continue;
974     }
975     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
976 
977     allSliceParams.push_back(computeSliceParameters(
978         builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
979         omitPartialTileCheck));
980   }
981 
982   return allSliceParams;
983 }
984 
985 SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
986                                    LinalgOp linalgOp, ValueRange valuesToTile,
987                                    ArrayRef<OpFoldResult> ivs,
988                                    ArrayRef<OpFoldResult> tileSizes,
989                                    ArrayRef<OpFoldResult> sizeBounds,
990                                    bool omitPartialTileCheck) {
991   SmallVector<std::optional<SliceParameters>> allSliceParameter =
992       computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
993                                 tileSizes, sizeBounds, omitPartialTileCheck);
994   SmallVector<Value> tiledShapes;
995   for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
996     Value valueToTile = std::get<0>(item);
997     std::optional<SliceParameters> sliceParams = std::get<1>(item);
998     tiledShapes.push_back(
999         sliceParams.has_value()
1000             ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
1001             : valueToTile);
1002   }
1003   return tiledShapes;
1004 }
1005 
1006 void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
1007                    ArrayRef<OpFoldResult> offsets) {
1008   IRRewriter rewriter(b);
1009   offsetIndices(rewriter, linalgOp, offsets);
1010 }
1011 
1012 void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
1013                    ArrayRef<OpFoldResult> offsets) {
1014   if (!linalgOp.hasIndexSemantics())
1015     return;
1016 
1017   for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
1018     if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
1019       continue;
1020     OpBuilder::InsertionGuard guard(b);
1021     b.setInsertionPointAfter(indexOp);
1022     AffineExpr index, offset;
1023     bindDims(b.getContext(), index, offset);
1024     OpFoldResult applied = makeComposedFoldedAffineApply(
1025         b, indexOp.getLoc(), index + offset,
1026         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
1027     Value materialized =
1028         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
1029     b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
1030       return use.getOwner() != materialized.getDefiningOp();
1031     });
1032   }
1033 }
1034 
1035 /// Get the reassociation maps to fold the result of an extract_slice (or the
1036 /// source of an insert_slice) operation with the given offsets and sizes to its
1037 /// rank-reduced version. This is only done for the cases where the size is 1
1038 /// and the offset is 0. Strictly speaking, the offset 0 is not required in
1039 /// general, but non-zero offsets are not handled by the SPIR-V backend at this
1040 /// point (and potentially cannot be handled).
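///
/// For example, assuming the unit sizes are static:
///   mixedSizes = [1, 4, 1, 8]  -->  reassociation = [[0, 1], [2, 3]]
/// i.e. every leading unit dimension is folded into the next non-unit
/// dimension.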
1041 std::optional<SmallVector<ReassociationIndices>>
1042 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
1043   SmallVector<ReassociationIndices> reassociation;
1044   ReassociationIndices curr;
1045   for (const auto &it : llvm::enumerate(mixedSizes)) {
1046     auto dim = it.index();
1047     auto size = it.value();
1048     curr.push_back(dim);
1049     auto attr = size.dyn_cast<Attribute>();
1050     if (attr && attr.cast<IntegerAttr>().getInt() == 1)
1051       continue;
1052     reassociation.emplace_back(ReassociationIndices{});
1053     std::swap(reassociation.back(), curr);
1054   }
1055   // When the reassociations are not empty, fold the remaining unit dimensions
1056   // into the last dimension. If the reassociations so far are empty, leave them
1057   // empty; this will fold everything to a rank-0 tensor.
1058   if (!curr.empty() && !reassociation.empty())
1059     reassociation.back().append(curr.begin(), curr.end());
1060   return reassociation;
1061 }
1062 
1063 /// Return the identity numeric value associated with the given op.
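/// For example, the neutral element of arith.addf is +0.0, of arith.mulf is
/// 1.0, and of arith.maxf is -infinity; std::nullopt is returned for ops
/// without a known neutral element.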
1064 std::optional<Attribute> getNeutralElement(Operation *op) {
1065   // Builder only used as helper for attribute creation.
1066   OpBuilder b(op->getContext());
1067   Type resultType = op->getResult(0).getType();
1068   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1069     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
1070     if (isa<arith::AddFOp>(op))
1071       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
1072     if (isa<arith::MulFOp>(op))
1073       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
1074     if (isa<arith::MaxFOp>(op))
1075       return b.getFloatAttr(resultType,
1076                             llvm::APFloat::getInf(semantic, /*Negative=*/true));
1077     if (isa<arith::MinFOp>(op))
1078       return b.getFloatAttr(
1079           resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
1080     return std::nullopt;
1081   }
1082   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
1083     return b.getIntegerAttr(resultType, 0);
1084   if (isa<arith::AndIOp>(op))
1085     return b.getIntegerAttr(resultType, -1);
1086   if (isa<arith::MaxSIOp>(op))
1087     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
1088   if (isa<arith::MinSIOp>(op))
1089     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
1090   if (isa<arith::MulIOp>(op))
1091     return b.getIntegerAttr(resultType, 1);
1092   return std::nullopt;
1093 }
1094 
1095 } // namespace linalg
1096 } // namespace mlir
1097