1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Passes.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Arith/Utils/Utils.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/Linalg/IR/Linalg.h"
21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
24 #include "mlir/Dialect/Tensor/IR/Tensor.h"
25 #include "mlir/Dialect/Utils/IndexingUtils.h"
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/ValueRange.h"
30 #include "mlir/Transforms/FoldUtils.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/Support/CommandLine.h"
34 #include <utility>
35 
36 namespace mlir {
37 #define GEN_PASS_DEF_LINALGTILINGPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
39 } // namespace mlir
40 
41 using namespace mlir;
42 using namespace mlir::affine;
43 using namespace mlir::linalg;
44 using namespace mlir::scf;
45 
46 #define DEBUG_TYPE "linalg-tiling"
47 
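// Note on makeTiledLoopRanges below: loops whose tile size is 0 are dropped
// from the returned ranges, and the returned LoopIndexToRangeIndexMap records
// how the surviving loops map to range indices. As a small illustrative
// example (hypothetical values): with loop-order shape sizes [128, 64], i.e.
// after applying `map`, and tile sizes [32, 0], only loop 0 is tiled; the
// result is a single Range{0, 128, 32} and the map {0 -> 0}.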
48 std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
49 mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
50                                   ArrayRef<OpFoldResult> allShapeSizes,
51                                   ArrayRef<OpFoldResult> allTileSizes) {
52   assert(allTileSizes.size() == map.getNumResults());
53   // Apply `map` to get shape sizes in loop order.
54   SmallVector<OpFoldResult> shapeSizes =
55       makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
56   SmallVector<OpFoldResult> tileSizes(allTileSizes);
57 
58   // Traverse the tile sizes, which are in loop order, and erase zeros everywhere.
59   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
60   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
61     if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
62         static_cast<int64_t>(0)) {
63       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
64       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
65       ++zerosCount;
66       continue;
67     }
68     loopIndexToRangeIndex[idx] = idx - zerosCount;
69   }
70 
71   // Create a new range with the applied tile sizes.
72   SmallVector<Range, 4> res;
73   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
74     res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
75   return std::make_tuple(res, loopIndexToRangeIndex);
76 }
77 
78 void mlir::linalg::transformIndexOps(
79     RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
80     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
81   SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
82   for (auto en : enumerate(allIvs)) {
83     auto rangeIndex = loopIndexToRangeIndex.find(en.index());
84     if (rangeIndex == loopIndexToRangeIndex.end())
85       continue;
86     en.value() = ivs[rangeIndex->second];
87   }
88   offsetIndices(b, op, getAsOpFoldResult(allIvs));
89 }
90 
91 /// Asserts that the given index-typed value is strictly positive. If the value
92 /// is an attribute, asserts at compile time, otherwise emits an assertion
93 /// checked at runtime.
94 static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
95                                          OpFoldResult value) {
96   if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
97     assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
98            "expected strictly positive tile size and divisor");
99     return;
100   }
101 
102   Value zero = b.create<arith::ConstantIndexOp>(0);
103   Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
104                                             cast<Value>(value), zero);
105   b.create<cf::AssertOp>(
106       condition,
107       b.getStringAttr("expected strictly positive tile size and divisor"));
108 }
109 
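// Worked example for the computation below (illustrative values): for a loop
// range of 15 and a target size of 8, the first tile size is 8 with trip count
// 1 and a remainder of 7. Repeatedly halving towards powers of two then yields
// tile sizes [8, 4, 2, 1] with trip counts [1, 1, 1, 1], and since
// 8 + 4 + 2 + 1 == 15 the final coverage check succeeds.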
110 FailureOr<StaticContinuousTileSizeSpecification>
111 mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
112                                                unsigned dimension,
113                                                unsigned targetSize) {
114 
115   assert(!op.hasDynamicShape() &&
116          "cannot compute static multi-tile sizes for an op with dynamic shape");
117   assert(targetSize > 0 && "target size must be positive");
118   assert(dimension < op.getNumLoops() && "dimension overflow");
119 
120   StaticContinuousTileSizeSpecification spec;
121   int64_t loopRange = op.getStaticLoopRanges()[dimension];
122   int64_t tripCount = loopRange / targetSize;
123 
124   unsigned tileSize = targetSize;
125 
126   spec.tileSizes.push_back(tileSize);
127   spec.tripCounts.push_back(tripCount);
128 
129   int64_t remainderChunk = loopRange % targetSize;
130 
131   while (tileSize > 1 && remainderChunk != 0) {
132 
133     uint64_t maxPower = llvm::bit_floor(tileSize);
134     tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
135 
136     tripCount = remainderChunk / tileSize;
137 
138     if (tripCount > 0) {
139       spec.tileSizes.push_back(tileSize);
140       spec.tripCounts.push_back(tripCount);
141     }
142 
143     remainderChunk = remainderChunk % tileSize;
144   }
145 
146   auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
147                             SmallVector<int64_t> tripCounts,
148                             int64_t range) -> bool {
149     int64_t computedRange = 0;
150     for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
151       computedRange += tileSize * tripCount;
152     return range == computedRange;
153   };
154 
155   if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
156     return failure();
157 
158   return spec;
159 }
160 
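// The function below is the IR-emitting counterpart of
// computeStaticContinuousTileSizes: it produces the tile sizes and trip counts
// as Values (constants and affine.apply results) so that dynamic loop ranges
// are supported, and it optionally emits a runtime assertion that the target
// size is strictly positive.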
161 FailureOr<ContinuousTileSizeSpecification>
162 mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
163                                          unsigned dimension,
164                                          OpFoldResult targetSize,
165                                          bool emitAssertions) {
166 
167   SmallVector<Range> loopRanges = op.getIterationDomain(builder);
168   unsigned numLoops = loopRanges.size();
169 
170   // Bail out on dimension overflow.
171   if (dimension >= numLoops)
172     return failure();
173 
174   // The code below works only on values.
175   Location loc = op->getLoc();
176   ImplicitLocOpBuilder b(loc, builder);
177   if (emitAssertions) {
178     emitIsPositiveIndexAssertion(b, targetSize);
179   }
180   Value targetSizeValue =
181       getValueOrCreateConstantIndexOp(builder, loc, targetSize);
182 
183   // Find the trip count of the iteration space dimension for which the tile
184   // sizes are computed.
185   Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
186                                                     loopRanges[dimension].size);
187   ContinuousTileSizeSpecification spec;
188 
189   // Compute the tile sizes and the respective numbers of tiles.
190   AffineExpr s0 = b.getAffineSymbolExpr(0);
191   AffineExpr s1 = b.getAffineSymbolExpr(1);
192   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
193     return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
194   };
195 
196   Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
197   Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
198 
199   OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
200       b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
201 
202   // The emitAssertions path above already asserts that targetSize is
203   // a positive integer.
204   uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
205 
206   assert(tileSizeInt > 0 && "target size must be positive");
207 
208   spec.tileSizes.push_back(targetSizeValue);
209   spec.tripCounts.push_back(tripCountValue);
210 
211   while (tileSizeInt > 1) {
212     uint64_t maxPower = llvm::bit_floor(tileSizeInt);
213     tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
214     auto constStepOp =
215         builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
216     tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
217 
218     tripCountSize = affine::makeComposedFoldedAffineApply(
219         b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
220 
221     // Optimization: drop this tile size if its trip count is statically zero.
222     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
223       auto intAttr = cast<IntegerAttr>(attr);
224       bool isTripCountZero = intAttr.getValue().isZero();
225 
226       if (!isTripCountZero) {
227         spec.tileSizes.push_back(constStepOp);
228         spec.tripCounts.push_back(tripCountValue);
229       }
230     } else {
231       spec.tileSizes.push_back(constStepOp);
232       spec.tripCounts.push_back(tripCountValue);
233     }
234 
235     remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
236   }
237 
238   return spec;
239 }
240 
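// Worked example for the computation below (illustrative values): for a loop
// range of 10, a target size of 3 and a divisor of 1:
//   a = 10, t = 3, totalTripCount = ceil(10 / 3) = 4,
//   lowTileSize = (10 / 4) * 1 = 2, highTileSize = 3,
//   highTripCount = 10 % 4 = 2, lowTripCount = 2,
// and 2 * 2 + 3 * 2 == 10, so the coverage check passes.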
241 FailureOr<StaticMultiSizeSpecification>
242 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
243                                           int64_t targetSize, int64_t divisor) {
244   assert(!op.hasDynamicShape() &&
245          "cannot compute static multi-tile sizes for an op with dynamic shape");
246   assert(targetSize > 0 && "target size must be positive");
247   assert(divisor > 0 && "divisor must be positive");
248   assert(dimension < op.getNumLoops() && "dimension overflow");
249 
250   StaticMultiSizeSpecification spec;
251   int64_t tripCount = op.getStaticLoopRanges()[dimension];
252   int64_t a = tripCount / divisor;
253   int64_t t = (targetSize + divisor - 1) / divisor;
254   int64_t totalTripCount = (a + t - 1) / t;
255   spec.lowTileSize = (a / totalTripCount) * divisor;
256   spec.highTileSize = spec.lowTileSize + divisor;
257   spec.highTripCount = a % totalTripCount;
258   spec.lowTripCount = totalTripCount - spec.highTripCount;
259   if (spec.lowTileSize * spec.lowTripCount +
260           spec.highTileSize * spec.highTripCount !=
261       tripCount) {
262     return failure();
263   }
264   return spec;
265 }
266 
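// The function below emits the same arithmetic as computeStaticMultiTileSizes
// but as composed affine.apply ops, so it also handles dynamic iteration
// domains; when `emitAssertions` is set, the positivity of the inputs and the
// coverage property are additionally checked at runtime.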
267 FailureOr<MultiSizeSpecification>
268 mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
269                                     unsigned dimension, OpFoldResult targetSize,
270                                     OpFoldResult divisor, bool emitAssertions) {
271   // Bail out on dimension overflow.
272   if (dimension >= op.getNumLoops())
273     return failure();
274 
275   // The code below works only on values.
276   Location loc = op.getLoc();
277   ImplicitLocOpBuilder b(loc, builder);
278   if (emitAssertions) {
279     emitIsPositiveIndexAssertion(b, targetSize);
280     emitIsPositiveIndexAssertion(b, divisor);
281   }
282   Value targetSizeValue =
283       getValueOrCreateConstantIndexOp(builder, loc, targetSize);
284   Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);
285 
286   // Find the trip count of the iteration space dimension for which the tile
287   // sizes are computed.
288   SmallVector<OpFoldResult> allShapes =
289       op.createFlatListOfOperandDims(b, b.getLoc());
290   AffineMap shapesToLoops = op.getShapesToLoopsMap();
291   SmallVector<OpFoldResult> loopRanges =
292       makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
293                                                allShapes);
294   Value tripCount =
295       getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
296 
297   // Compute the tile sizes and the respective numbers of tiles.
298   AffineExpr s0 = b.getAffineSymbolExpr(0);
299   AffineExpr s1 = b.getAffineSymbolExpr(1);
300   AffineExpr s2 = b.getAffineSymbolExpr(2);
301   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
302     return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
303   };
304   Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
305   Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
306   Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
307   Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
308   Value v = apply(s0 % s1, {a, d});
309   Value u = apply(s0 - s1, {d, v});
310 
311   MultiSizeSpecification spec;
312   spec.lowTileSize = s;
313   spec.highTileSize = apply(s0 + s1, {s, divisorValue});
314   spec.lowTripCount = u;
315   spec.highTripCount = v;
316 
317   // If requested, emit the check that the tile sizes are computed correctly.
318   // For example, for an iteration dimension of size 15 and a divisor of 8, it is
319   // impossible to find two tile sizes, both divisible by 8, that fully cover the
320   // original iteration space.
321   if (emitAssertions) {
322     AffineExpr s3 = builder.getAffineSymbolExpr(3);
323     Value coveredSize =
324         apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
325                                   spec.highTileSize, spec.highTripCount});
326     Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
327                                            coveredSize, tripCount);
328     b.create<cf::AssertOp>(
329         equals, builder.getStringAttr(
330                     "could not compute dynamic multi-size tile shapes"));
331   }
332 
333   return spec;
334 }
335 
336 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
337 /// less than `iterationSize`.
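/// For example, with static tileSize = 8, numThreads = 4 and iterationSize =
/// 32, the maximum offset is 8 * 3 = 24 < 32 and the check can be omitted.
/// Any non-constant operand makes this conservatively return false.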
338 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
339                                            OpFoldResult numThreads,
340                                            OpFoldResult iterationSize) {
341   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
342   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
343   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
344   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
345     return false;
346   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
347 }
348 
349 /// Build an `affine.max` of all the `vals`.
350 static OpFoldResult buildMax(OpBuilder &b, Location loc,
351                              ArrayRef<OpFoldResult> vals) {
352   return affine::makeComposedFoldedAffineMax(
353       b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
354       vals);
355 }
356 
357 /// Build an `affine.min` of all the `vals`.
358 static OpFoldResult buildMin(OpBuilder &b, Location loc,
359                              ArrayRef<OpFoldResult> vals) {
360   return affine::makeComposedFoldedAffineMin(
361       b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
362       vals);
363 }
364 
365 /// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
366 /// number of threads.
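/// For example (illustrative values), for a single loop range [0, 100) and
/// numThreads = [3], thread id `t` gets offset `t * 34` and size
/// `min(34, 100 - t * 34)`, so the last thread covers the remaining 32
/// elements.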
367 static void calculateTileOffsetsAndSizes(
368     RewriterBase &b, Location loc, scf::ForallOp forallOp,
369     ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
370     bool omitTileOffsetBoundsCheck,
371     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
372     SmallVector<OpFoldResult> &tiledOffsets,
373     SmallVector<OpFoldResult> &tiledSizes) {
374   OpBuilder::InsertionGuard g(b);
375   b.setInsertionPointToStart(forallOp.getBody(0));
376 
377   SmallVector<Value> threadIds = forallOp.getInductionVars();
378   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
379       numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
380   int64_t nLoops = loopRanges.size();
381   tiledOffsets.reserve(nLoops);
382   tiledSizes.reserve(nLoops);
383   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
384     bool overflow = loopIdx >= numThreads.size();
385     bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
386     // Degenerate case: take the whole domain.
387     if (overflow || isZero) {
388       tiledOffsets.push_back(loopRanges[loopIdx].offset);
389       tiledSizes.push_back(loopRanges[loopIdx].size);
390       continue;
391     }
392 
393     // Tiled case: compute the offset and size.
394     AffineExpr i, j, m, n, o;
395     bindDims(b.getContext(), i, j);
396     bindSymbols(b.getContext(), m, n, o);
397     OpFoldResult size = loopRanges[loopIdx].size;
398     OpFoldResult offset = loopRanges[loopIdx].offset;
399     OpFoldResult threadId = threadIds[threadIdIdx];
400     // Symbolic fixed max size per thread.
401     // TODO: floor + 0/1 depending on case for better load-balancing.
402     OpFoldResult tileSizePerThread =
403         nominalTileSizes.has_value()
404             ? (*nominalTileSizes)[loopIdx]
405             : makeComposedFoldedAffineApply(
406                   b, loc, m.ceilDiv(n),
407                   ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
408 
409     // Dynamic offset shifted by threadId * tileSizePerThread.
410     OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
411         b, loc, i + j * m, {offset, threadId, tileSizePerThread});
412     // Dynamic upper-bound depending on the threadId.
413     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
414         b, loc, i + j * m - n,
415         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
416     if (!isConstantIntValue(residualTileSize, 0)) {
417       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
418           b, loc, -i + m, {offsetPerThread, size});
419       tileSizePerThread =
420           buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
421     }
422 
423     tiledOffsets.push_back(offsetPerThread);
424     // TODO: if tileSizePerThread <= 0 early exit.
425     if (!omitTileOffsetBoundsCheck &&
426         !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
427                                         nonZeroNumThreads[threadIdIdx], size))
428       tileSizePerThread =
429           buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});
430 
431     tiledSizes.push_back(tileSizePerThread);
432     ++threadIdIdx;
433   }
434 }
435 
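/// Tiles `op` by `tileSizes`, materializing the tile loops as `LoopTy`
/// operations (instantiated below with scf::ForOp and scf::ParallelOp). A tile
/// size of zero means the corresponding loop is left untiled; if all tile
/// sizes are zero, the op is simply cloned.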
436 template <typename LoopTy>
437 static FailureOr<TiledLinalgOp>
438 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
439                  const LinalgTilingOptions &options) {
440   OpBuilder::InsertionGuard g(b);
441 
442   auto nLoops = op.getNumLoops();
443   // More tile sizes than loops may be provided; only keep the first nLoops.
444   tileSizes = tileSizes.take_front(nLoops);
445 
446   if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
447         return getConstantIntValue(ofr) == static_cast<int64_t>(0);
448       })) {
449     TiledLinalgOp tiledOp;
450     tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
451     tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
452                                  tiledOp.op->result_end());
453     return tiledOp;
454   }
455 
456   // 1. Build the tiled loop ranges.
457   SmallVector<OpFoldResult> allShapeSizes =
458       op.createFlatListOfOperandDims(b, op.getLoc());
459   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
460   if (!shapeSizesToLoopsMap)
461     return failure();
462 
463   auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
464       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
465 
466   SmallVector<utils::IteratorType, 4> iteratorTypes;
467   for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
468     if (loopIndexToRangeIndex.count(attr.index()))
469       iteratorTypes.push_back(attr.value());
470   }
471   // If interchangeVector is empty, use the identity map. Otherwise, build the
472   // inverse permutation map from it.
473   auto invPermutationMap =
474       AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
475   if (!options.interchangeVector.empty()) {
476     // Based on the pruned iterations (due to zero tile size), recompute the
477     // interchange vector.
478     SmallVector<unsigned, 4> interchangeVector;
479     interchangeVector.reserve(options.interchangeVector.size());
480     for (auto pos : options.interchangeVector) {
481       auto it = loopIndexToRangeIndex.find(pos);
482       if (it == loopIndexToRangeIndex.end())
483         continue;
484       interchangeVector.push_back(it->second);
485     }
486     // The interchange vector is guaranteed to be a permutation, so
487     // `inversePermutation` must succeed.
488     invPermutationMap = inversePermutation(
489         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
490     assert(invPermutationMap);
491     SmallVector<int64_t> permutation(interchangeVector.begin(),
492                                      interchangeVector.end());
493     applyPermutationToVector(loopRanges, permutation);
494     applyPermutationToVector(iteratorTypes, permutation);
495   }
496 
497   // Handle distribution. Create a vector of the same size as the loops that
498   // are to be tiled.
499   SmallVector<linalg::ProcInfo> procInfo;
500   if (options.distribution) {
501     procInfo.resize(
502         iteratorTypes.size(),
503         linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
504     // Collect the loop ranges of the leading tiled loops that are parallel.
505     SmallVector<Range> parallelLoopRanges;
506     for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
507       if (!isParallelIterator(iteratorType.value()))
508         break;
509       parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
510     }
511     auto returnedProcInfo =
512         options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
513     unsigned procIdIdx = 0;
514     // Update the distribution information for the loops.
515     for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
516       if (!isParallelIterator(iteratorType.value()))
517         break;
518       procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
519     }
520   }
521 
522   // 2. Create the tiled loops.
523   LinalgOp res = op;
524   SmallVector<Value, 4> ivs, tensorResults;
525   auto tiledLoopBodyBuilder =
526       [&](OpBuilder &builder, Location loc, ValueRange localIvs,
527           ValueRange operandValuesToUse) -> scf::ValueVector {
528     ivs.assign(localIvs.begin(), localIvs.end());
529 
530     // When an `interchangeVector` is present, it has been applied to the
531     // loop ranges and the iterator types. Apply its inverse to the
532     // resulting loop `ivs` to match the op definition.
533     SmallVector<Value, 4> interchangedIvs;
534     if (!options.interchangeVector.empty()) {
535       for (AffineExpr result : invPermutationMap.getResults())
536         interchangedIvs.push_back(
537             ivs[cast<AffineDimExpr>(result).getPosition()]);
538     } else {
539       interchangedIvs.assign(ivs.begin(), ivs.end());
540     }
541 
542     // Tile the `operandValuesToUse`, which are either the `op` operands
543     // themselves or the tile loop arguments forwarding them.
544     assert(operandValuesToUse.size() ==
545                static_cast<size_t>(op->getNumOperands()) &&
546            "expect the number of operands and inputs and outputs to match");
547     SmallVector<Value> valuesToTile = operandValuesToUse;
548     SmallVector<OpFoldResult> sizeBounds =
549         makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
550                                                  allShapeSizes);
551     SmallVector<Value> tiledOperands = makeTiledShapes(
552         b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
553         sizeBounds,
554         /*omitPartialTileCheck=*/false);
555 
556     SmallVector<Type> resultTensorTypes =
557         getTensorOutputTypes(op, tiledOperands);
558     res = clone(b, op, resultTensorTypes, tiledOperands);
559     tensorResults =
560         insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
561     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
562   };
563   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
564                                  tiledLoopBodyBuilder, procInfo);
565 
566   // 3. Transform IndexOp results w.r.t. the tiling.
567   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
568 
569   // 4. Gather the newly created loops and return them with the new op.
570   SmallVector<Operation *, 8> loops;
571   loops.reserve(ivs.size());
572   for (auto iv : ivs) {
573     if (isa<BlockArgument>(iv)) {
574       loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
575       assert(loops.back() && "no owner found for induction variable!");
576     } else {
577       // TODO: Instead of storing nullptr, try to recover the op that defines the
578       // induction variable when it is not a BlockArgument.
579       loops.push_back(nullptr);
580     }
581   }
582 
583   // 5. Get the tensor results from the outermost loop if available. Otherwise
584   // use the previously captured `tensorResults`.
585   Operation *outermostLoop = nullptr;
586   for (Operation *loop : loops)
587     if ((outermostLoop = loop))
588       break;
589 
590   return TiledLinalgOp{
591       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
592 }
593 
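// Reduction tiling with scf.forall: each thread of the forall computes a
// partial reduction into its own slice of the initial tensors created for the
// partial reduction, the partial results are inserted back in parallel, and
// they are merged into the final result after the forall (steps 1.-8. below).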
594 FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
595     RewriterBase &b, PartialReductionOpInterface op,
596     ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
597     std::optional<ArrayAttr> mapping) {
598   Location loc = op.getLoc();
599   OpBuilder::InsertionGuard g(b);
600 
601   // Ops implementing PartialReductionOpInterface are expected to implement
602   // TilingInterface.
603   // TODO: proper core mechanism to tie interfaces together.
604   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
605 
606   // Ops implementing PartialReductionOpInterface are not necessarily expected
607   // to implement TilingInterface.. This cast is unsafe atm.
608   // TODO: proper core mechanism to tie interfaces together.
609   // TODO: this function requires a pair of interfaces ..
610   auto destinationStyleOp =
611       dyn_cast<DestinationStyleOpInterface>(op.getOperation());
612   if (!destinationStyleOp)
613     return b.notifyMatchFailure(op, "not a destination style op");
614 
615   // Currently this only works for Linalg ops.
616   auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
617   if (!linalgOp)
618     return b.notifyMatchFailure(op, "not a linalg op");
619 
620   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
621   if (op->getNumResults() != 1)
622     return b.notifyMatchFailure(
623         op, "don't support ops with multiple results for now");
624 
625   SmallVector<utils::IteratorType> iterators =
626       tilingInterfaceOp.getLoopIteratorTypes();
627   SmallVector<unsigned> redDims;
628   linalgOp.getReductionDims(redDims);
629   if (redDims.size() != 1)
630     return b.notifyMatchFailure(
631         op, "only support ops with one reduction dimension.");
632   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
633     return b.notifyMatchFailure(op, "if tile sizes are present, they must have "
634                                     "as many elements as the number of threads");
635   int reductionDim = static_cast<int>(redDims.front());
636 
637   if (redDims.front() >= numThreads.size())
638     return b.notifyMatchFailure(
639         op, "reduction dimension must be mapped to threads");
640 
641   // 1. Create the initial tensor value.
642   FailureOr<SmallVector<Value>> maybeInitTensors =
643       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
644                                                   reductionDim);
645   if (failed(maybeInitTensors))
646     return b.notifyMatchFailure(
647         op, "Failed to create inital tensors for partial reduction");
648   SmallVector<Value> &initTensors = maybeInitTensors.value();
649 
650   // Gather destination tensors.
651   SmallVector<Value> dest;
652   if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
653     return b.notifyMatchFailure(op, "failed to get destination tensors");
654 
655   Operation *tiledOp = nullptr;
656 
657   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658       numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659   SmallVector<Value> materializedNonZeroNumThreads =
660       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661 
662   // 2. Create the ForallOp with an empty region.
663   scf::ForallOp forallOp = b.create<scf::ForallOp>(
664       loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
665       mapping);
666 
667   // 3. Calculate the tile offsets and sizes for the tiled op that will be
668   // nested under `forallOp`.
669   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
670   calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
671                                /*omitTileOffsetBoundsCheck =*/false,
672                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
673                                tiledSizes);
674 
675   // 4. Clone the tileable op and update its destination operands to use the
676   // output bbArgs of the ForallOp.
677   SmallVector<Value> tilingResults;
678   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
679   {
680     // 4.a. RAII guard, inserting within forallOp, before terminator.
681     OpBuilder::InsertionGuard g(b);
682     b.setInsertionPoint(forallOp.getTerminator());
683 
684     SmallVector<Value> tiledDpsInitOperands;
685     for (Value initOperand : destinationStyleOp.getDpsInits()) {
686       auto *it = llvm::find(dest, initOperand);
687       assert(it != dest.end() && "dest operand not found in dest");
688       unsigned destNum = std::distance(dest.begin(), it);
689       SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
690       SmallVector<OpFoldResult> outOffsets(numThreads.size(),
691                                            b.getIndexAttr(0));
692       SmallVector<OpFoldResult> sizes = tiledSizes;
693       sizes[reductionDim] = b.getIndexAttr(1);
694       outOffsets[reductionDim] = forallOp.getInductionVars()[0];
695       // TODO: use SubsetExtractOpInterface once it is available.
696       tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
697           loc, cast<RankedTensorType>(initOperand.getType()),
698           destBbArgs[destNum], outOffsets, sizes, strides));
699     }
700 
701     // 4.b. Clone the op and update init operands.
702     // We cannot use an IRMapping here because it can replace
703     // different OpOperands with the same value.
704     Operation *clonedOp = b.clone(*op.getOperation());
705     b.modifyOpInPlace(clonedOp, [&]() {
706       for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
707                cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
708                tiledDpsInitOperands)) {
709         initOperandPtr.set(tiledInitValue);
710       }
711     });
712 
713     // 5. Tile the cloned op and delete the clone.
714     if (tileSizes.empty()) {
715       FailureOr<TilingResult> tilingResult =
716           cast<TilingInterface>(clonedOp).getTiledImplementation(
717               b, tiledOffsets, tiledSizes);
718       if (failed(tilingResult))
719         return clonedOp->emitError("failed to tile op");
720       if (tilingResult->tiledOps.size() != 1) {
721         return clonedOp->emitError("expected a single produced tiled op, got ")
722                << tilingResult->tiledOps.size();
723       }
724       tiledOp = tilingResult->tiledOps.front();
725       tilingResults = tilingResult->tiledValues;
726     } else {
727       LinalgTilingOptions options;
728       FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
729           b, cast<LinalgOp>(clonedOp), tileSizes, options);
730       if (failed(maybeTiled))
731         return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
732 
733       SmallVector<Value> ids = forallOp.getInductionVars();
734       mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
735                             materializedNonZeroNumThreads);
736       if (maybeTiled->loops.size() != 1) {
737         return clonedOp->emitError("expected a single produced loop");
738       }
739       tiledOp = maybeTiled->op;
740       tilingResults = maybeTiled->loops.front()->getResults();
741     }
742 
743     b.eraseOp(clonedOp);
744   }
745 
746   // 6. Insert the partial reductions back into a new tensor.
747   for (auto [index, result, bbArg] : llvm::zip(
748            llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
749     // 6.a. Partial subset information is inserted just before the terminator.
750     OpBuilder::InsertionGuard g(b);
751     b.setInsertionPoint(forallOp.getTerminator());
752 
753     SmallVector<OpFoldResult> resultOffsets, resultSizes;
754     if (failed(tilingInterfaceOp.getResultTilePosition(
755             b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
756       return op->emitOpError("output offsets couldn't be calculated");
757     SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
758     int64_t offIdx = 0;
759     int64_t sizeIdx = 0;
760     for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
761       if (i == reductionDim) {
762         resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
763         resultSizesRank.push_back(b.getIndexAttr(1));
764         continue;
765       }
766       resultOffsetsRank.push_back(resultOffsets[offIdx++]);
767       resultSizesRank.push_back(resultSizes[sizeIdx++]);
768     }
769     SmallVector<OpFoldResult> strides(resultSizesRank.size(),
770                                       b.getIndexAttr(1));
771 
772     // 6.b. Parallel insertions are placed at the end of the combining
773     // terminator.
774     b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
775     b.create<tensor::ParallelInsertSliceOp>(
776         loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
777   }
778 
779   // 7. Merge the partial reductions.
780   b.setInsertionPointAfter(forallOp);
781   FailureOr<MergeResult> mergeResult =
782       op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
783   if (failed(mergeResult)) {
784     return failure();
785   }
786   b.replaceOp(op, mergeResult->replacements);
787 
788   // 8. Return.
789   ForallReductionTilingResult results;
790   results.initialValues = initTensors;
791   results.loops = forallOp;
792   results.parallelTiledOps.push_back(tiledOp);
793   results.mergeOps.append(mergeResult->mergeOps);
794   return results;
795 }
796 
797 template <typename LoopTy>
798 FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
799     RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
800   OpBuilder::InsertionGuard g(b);
801   b.setInsertionPoint(op);
802 
803   if (!options.tileSizeComputationFunction)
804     return failure();
805 
806   // Enforce the convention that "tiling by zero" skips tiling a particular
807   // dimension. This convention is significantly simpler to handle than
808   // adjusting affine maps to account for missing dimensions.
809   auto nLoops = op.getNumLoops();
810   SmallVector<OpFoldResult> tileSizeVector =
811       getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
812   if (tileSizeVector.size() < nLoops) {
813     tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
814   }
815 
816   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
817 }
818 
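// A minimal usage sketch (hypothetical caller; `rewriter` and `matmulOp` are
// assumed to exist and the tile sizes are arbitrary):
//
//   LinalgTilingOptions options;
//   options.setTileSizes({32, 32, 64})
//       .setLoopType(LinalgTilingLoopType::Loops);
//   FailureOr<TiledLinalgOp> tiled =
//       linalg::tileLinalgOp(rewriter, matmulOp, options);
//   if (failed(tiled))
//     return failure();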
819 FailureOr<TiledLinalgOp>
820 mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
821                            const LinalgTilingOptions &options) {
822   switch (options.loopType) {
823   case LinalgTilingLoopType::Loops:
824     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
825   case LinalgTilingLoopType::ParallelLoops:
826     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
827   default:;
828   }
829   return failure();
830 }
831 
832 namespace {
833 /// Helper classes for type list expansion.
834 template <typename... OpTypes>
835 class CanonicalizationPatternList;
836 
837 template <>
838 class CanonicalizationPatternList<> {
839 public:
840   static void insert(RewritePatternSet &patterns) {}
841 };
842 
843 template <typename OpTy, typename... OpTypes>
844 class CanonicalizationPatternList<OpTy, OpTypes...> {
845 public:
846   static void insert(RewritePatternSet &patterns) {
847     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
848     CanonicalizationPatternList<OpTypes...>::insert(patterns);
849   }
850 };
851 } // namespace
852 
853 RewritePatternSet
854 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
855   RewritePatternSet patterns(ctx);
856   populateLinalgTilingCanonicalizationPatterns(patterns);
857   return patterns;
858 }
859 
860 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
861     RewritePatternSet &patterns) {
862   auto *ctx = patterns.getContext();
863   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
864   affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
865   affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
866   affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
867   arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
868 
869   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
870   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
871 
872   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
873   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
874 
875   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
876   tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
877   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
878   tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
879   tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
880   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
881 
882   CanonicalizationPatternList<
883 #define GET_OP_LIST
884 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
885       >::insert(patterns);
886 }
887