//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}
// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

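/// For example (illustrative), a `linalg.generic` computing
/// `C(i, j) = A(i, j) + B(j)` passes all of the checks below: every iterator
/// is parallel, every indexing map is a projected permutation, the init map is
/// a permutation, and the body contains only scalar arithmetic. A matmul fails
/// the first check since it carries a reduction iterator.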
bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

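// The folding below looks, roughly, for the following pseudo-IR pattern
// (names are illustrative; the chain of LinalgOps may also be empty):
//
//   %padded = tensor.pad %s low[0, ...] high[...]
//   %r      = linalg.some_op ... outs(%padded) ...
//   %source = tensor.extract_slice %r [...]
//
// If the padding value, the slice sizes, and the result type all match, the
// already padded value `%r` is reused instead of creating a second tensor.pad
// on top of `%source`.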
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector.begin(), transposeVector.end()),
          b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  Region &body = transposeOp.getRegion();
  Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{},
                                   {elementType, elementType}, {loc, loc});
  b.create<YieldOp>(loc, bodyBlock->getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}

/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
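/// For a cyclic distribution this computes, in pseudo-code:
///   lb   = lb + procId * step
///   step = nprocs * step
/// so that processor `procId` executes iterations procId, procId + nprocs,
/// procId + 2 * nprocs, ... of the original loop; `ub` is left unchanged.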
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn`, which accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
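/// For example (illustrative), iterator types
/// [parallel, parallel, reduction, parallel] and no distribution produce a
/// nest of the form:
///   scf.parallel (%i0, %i1) = ... {  // two outer parallel dims, one op
///     scf.for %i2 = ... {            // the reduction dim
///       scf.parallel (%i3) = ... {   // the trailing parallel dim
///         <body>
///       }
///     }
///   }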
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}

/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Value materializeTiledShape(OpBuilder &builder, Location loc,
                                   Value valueToTile,
                                   const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp->getResult(0);
}

Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}

SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index; the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // Resulting size needs to be made half open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the input window dimension's affine map has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
      // dimension and `s0` is the stride. Directly using the dimension sizes
      // of the output/filter window dimensions would give an incorrect result.
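      // For instance (illustrative): with the submap `(d0 * 2 + d1)` (output
      // window dim `d0`, filter dim `d1`, stride 2) and upper bounds
      // `[%ub0, %ub1]`, the accessed extent is
      // `((%ub0 - 1) * 2 + (%ub1 - 1)) + 1`, not `%ub0 * 2 + %ub1`, which
      // would overestimate it.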
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

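// For example (illustrative values): with `tileSizes = [0, 8, 0, 16]` and
// `ivs = [%i, %j]`, the computed offsets are `[0, %i, 0, %j]`: untiled
// dimensions get a zero offset and each tiled dimension consumes the next
// induction variable.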
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

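// For example (illustrative values): with `tileSizes = [8, 0]` and
// `sizeBounds = [%M, %N]`, the returned sizes are `[7, %N - 1]`; each entry is
// the tile size (or the full size bound when untiled) turned into a closed
// interval by subtracting one.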
SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}

SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
797          "more value to tile than operands.");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
    // an extract/insert slice pair for all output tensors simplifies follow up
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.

    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size is
/// 1 and the offset is 0. Strictly speaking the offset 0 is not required in
/// general, but non-zero offsets are not handled by the SPIR-V backend at this
/// point (and potentially cannot be handled).
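///
/// For example (illustrative): `mixedSizes = [1, %d, 1, 4]` yields the
/// reassociation `[[0, 1], [2, 3]]`: each unit dimension is grouped with the
/// next non-unit dimension, and trailing unit dimensions are folded into the
/// last group.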
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociation so far is empty, leave it
  // empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir
909