xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (revision 4f279a570110e3d688356a327637c57071f4b13b)
1b6281940SNicolas Vasilache //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2b6281940SNicolas Vasilache //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b6281940SNicolas Vasilache //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8b6281940SNicolas Vasilache //
9b6281940SNicolas Vasilache // This file implements the linalg dialect Tiling pass.
10b6281940SNicolas Vasilache //
11b6281940SNicolas Vasilache //===----------------------------------------------------------------------===//
12b6281940SNicolas Vasilache 
1367d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h"
141fc096afSMehdi Amini 
15e99fae89SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
16f7fda6baSThomas Raoux #include "mlir/Dialect/Affine/LoopUtils.h"
17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h"
183963b4d0SAlex Zinenko #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1967d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
20b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
21307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
22e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
238b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
24129d6e55SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h"
25f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h"
26b6281940SNicolas Vasilache #include "mlir/IR/AffineExpr.h"
27b6281940SNicolas Vasilache #include "mlir/IR/AffineMap.h"
2806ca5c81SNicolas Vasilache #include "mlir/IR/BuiltinOps.h"
2906ca5c81SNicolas Vasilache #include "mlir/IR/ValueRange.h"
30b6281940SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h"
31b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3206ca5c81SNicolas Vasilache #include "llvm/ADT/STLExtras.h"
33b6281940SNicolas Vasilache #include "llvm/Support/CommandLine.h"
3467d0d7acSMichele Scuttari #include <utility>
3567d0d7acSMichele Scuttari 
3667d0d7acSMichele Scuttari namespace mlir {
3767d0d7acSMichele Scuttari #define GEN_PASS_DEF_LINALGTILINGPASS
3867d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc"
3967d0d7acSMichele Scuttari } // namespace mlir
40b6281940SNicolas Vasilache 
41b6281940SNicolas Vasilache using namespace mlir;
424c48f016SMatthias Springer using namespace mlir::affine;
43b6281940SNicolas Vasilache using namespace mlir::linalg;
44c25b20c0SAlex Zinenko using namespace mlir::scf;
45b6281940SNicolas Vasilache 
46039b969bSMichele Scuttari #define DEBUG_TYPE "linalg-tiling"
47039b969bSMichele Scuttari 
48c9620389SAlexander Belyaev std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
49c9620389SAlexander Belyaev mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
50e99fae89SAlex Zinenko                                   ArrayRef<OpFoldResult> allShapeSizes,
51e99fae89SAlex Zinenko                                   ArrayRef<OpFoldResult> allTileSizes) {
52b6281940SNicolas Vasilache   assert(allTileSizes.size() == map.getNumResults());
53a3adcba6SNicolas Vasilache   // Apply `map` to get shape sizes in loop order.
54e99fae89SAlex Zinenko   SmallVector<OpFoldResult> shapeSizes =
55e99fae89SAlex Zinenko       makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
565262865aSKazu Hirata   SmallVector<OpFoldResult> tileSizes(allTileSizes);
57b6281940SNicolas Vasilache 
58b6281940SNicolas Vasilache   // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
59bae8a7a7SAlexander Belyaev   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
60bae8a7a7SAlexander Belyaev   for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
61cb7bda2aSMatthias Springer     if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
62cb7bda2aSMatthias Springer         static_cast<int64_t>(0)) {
63a3adcba6SNicolas Vasilache       shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
64bae8a7a7SAlexander Belyaev       tileSizes.erase(tileSizes.begin() + idx - zerosCount);
65bae8a7a7SAlexander Belyaev       ++zerosCount;
66bae8a7a7SAlexander Belyaev       continue;
67b6281940SNicolas Vasilache     }
68bae8a7a7SAlexander Belyaev     loopIndexToRangeIndex[idx] = idx - zerosCount;
69b6281940SNicolas Vasilache   }
70b6281940SNicolas Vasilache 
71b6281940SNicolas Vasilache   // Create a new range with the applied tile sizes.
72e3de249aSNicolas Vasilache   SmallVector<Range, 4> res;
73004a3d4fSNicolas Vasilache   for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
7470e99f38SAlex Zinenko     res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
75bae8a7a7SAlexander Belyaev   return std::make_tuple(res, loopIndexToRangeIndex);
76b6281940SNicolas Vasilache }
77b6281940SNicolas Vasilache 
78c9620389SAlexander Belyaev void mlir::linalg::transformIndexOps(
79c9620389SAlexander Belyaev     RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
808ea5d190STobias Gysi     const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
8190b7817eSTobias Gysi   SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
828c258fdaSJakub Kuderski   for (auto en : enumerate(allIvs)) {
8390b7817eSTobias Gysi     auto rangeIndex = loopIndexToRangeIndex.find(en.index());
848ea5d190STobias Gysi     if (rangeIndex == loopIndexToRangeIndex.end())
858ea5d190STobias Gysi       continue;
8690b7817eSTobias Gysi     en.value() = ivs[rangeIndex->second];
878ea5d190STobias Gysi   }
88e99fae89SAlex Zinenko   offsetIndices(b, op, getAsOpFoldResult(allIvs));
898ea5d190STobias Gysi }
908ea5d190STobias Gysi 
913963b4d0SAlex Zinenko /// Asserts that the given index-typed value is strictly positive. If the value
923963b4d0SAlex Zinenko /// is an attribute, asserts at compile time, otherwise emits an assertion
933963b4d0SAlex Zinenko /// checked at runtime.
943963b4d0SAlex Zinenko static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
953963b4d0SAlex Zinenko                                          OpFoldResult value) {
9668f58812STres Popp   if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
975550c821STres Popp     assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
983963b4d0SAlex Zinenko            "expected strictly positive tile size and divisor");
993963b4d0SAlex Zinenko     return;
1003963b4d0SAlex Zinenko   }
1013963b4d0SAlex Zinenko 
1023963b4d0SAlex Zinenko   Value zero = b.create<arith::ConstantIndexOp>(0);
1033963b4d0SAlex Zinenko   Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
104*4f279a57SKazu Hirata                                             cast<Value>(value), zero);
1053963b4d0SAlex Zinenko   b.create<cf::AssertOp>(
1063963b4d0SAlex Zinenko       condition,
1073963b4d0SAlex Zinenko       b.getStringAttr("expected strictly positive tile size and divisor"));
1083963b4d0SAlex Zinenko }
1093963b4d0SAlex Zinenko 
110a9efcbf4Smuneebkhan85 FailureOr<StaticContinuousTileSizeSpecification>
111a9efcbf4Smuneebkhan85 mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
112a9efcbf4Smuneebkhan85                                                unsigned dimension,
113a9efcbf4Smuneebkhan85                                                unsigned targetSize) {
114a9efcbf4Smuneebkhan85 
115a9efcbf4Smuneebkhan85   assert(!op.hasDynamicShape() &&
116a9efcbf4Smuneebkhan85          "cannot compute static multi-tile sizes for an op with dynamic shape");
117a9efcbf4Smuneebkhan85   assert(targetSize > 0 && "target size must be non-negative");
118a9efcbf4Smuneebkhan85   assert(dimension < op.getNumLoops() && "dimension overflow");
119a9efcbf4Smuneebkhan85 
120a9efcbf4Smuneebkhan85   StaticContinuousTileSizeSpecification spec;
121a9efcbf4Smuneebkhan85   int64_t loopRange = op.getStaticLoopRanges()[dimension];
122a9efcbf4Smuneebkhan85   int64_t tripCount = loopRange / targetSize;
123a9efcbf4Smuneebkhan85 
124a9efcbf4Smuneebkhan85   unsigned tileSize = targetSize;
125a9efcbf4Smuneebkhan85 
126a9efcbf4Smuneebkhan85   spec.tileSizes.push_back(tileSize);
127a9efcbf4Smuneebkhan85   spec.tripCounts.push_back(tripCount);
128a9efcbf4Smuneebkhan85 
129a9efcbf4Smuneebkhan85   int64_t remainderChunk = loopRange % targetSize;
130a9efcbf4Smuneebkhan85 
131a9efcbf4Smuneebkhan85   while (tileSize > 1 && remainderChunk != 0) {
132a9efcbf4Smuneebkhan85 
133a9efcbf4Smuneebkhan85     uint64_t maxPower = llvm::bit_floor(tileSize);
134a9efcbf4Smuneebkhan85     tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
135a9efcbf4Smuneebkhan85 
136a9efcbf4Smuneebkhan85     tripCount = remainderChunk / tileSize;
137a9efcbf4Smuneebkhan85 
138a9efcbf4Smuneebkhan85     if (tripCount > 0) {
139a9efcbf4Smuneebkhan85       spec.tileSizes.push_back(tileSize);
140a9efcbf4Smuneebkhan85       spec.tripCounts.push_back(tripCount);
141a9efcbf4Smuneebkhan85     }
142a9efcbf4Smuneebkhan85 
143a9efcbf4Smuneebkhan85     remainderChunk = remainderChunk % tileSize;
144a9efcbf4Smuneebkhan85   }
145a9efcbf4Smuneebkhan85 
146a9efcbf4Smuneebkhan85   auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
147a9efcbf4Smuneebkhan85                             SmallVector<int64_t> tripCounts,
148a9efcbf4Smuneebkhan85                             int64_t range) -> bool {
149a9efcbf4Smuneebkhan85     int64_t computedRange = 0;
150a9efcbf4Smuneebkhan85     for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
151a9efcbf4Smuneebkhan85       computedRange += tileSize * tripCount;
152a9efcbf4Smuneebkhan85     return range == computedRange;
153a9efcbf4Smuneebkhan85   };
154a9efcbf4Smuneebkhan85 
155a9efcbf4Smuneebkhan85   if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
156a9efcbf4Smuneebkhan85     return failure();
157a9efcbf4Smuneebkhan85 
158a9efcbf4Smuneebkhan85   return spec;
159a9efcbf4Smuneebkhan85 }
160a9efcbf4Smuneebkhan85 
161a9efcbf4Smuneebkhan85 FailureOr<ContinuousTileSizeSpecification>
162a9efcbf4Smuneebkhan85 mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
163a9efcbf4Smuneebkhan85                                          unsigned dimension,
164a9efcbf4Smuneebkhan85                                          OpFoldResult targetSize,
165a9efcbf4Smuneebkhan85                                          bool emitAssertions) {
166a9efcbf4Smuneebkhan85 
167a9efcbf4Smuneebkhan85   SmallVector<Range> loopRanges = op.getIterationDomain(builder);
168a9efcbf4Smuneebkhan85   unsigned numLoops = loopRanges.size();
169a9efcbf4Smuneebkhan85 
170a9efcbf4Smuneebkhan85   // Bail out on dimension overflow.
171a9efcbf4Smuneebkhan85   if (dimension >= numLoops)
172a9efcbf4Smuneebkhan85     return failure();
173a9efcbf4Smuneebkhan85 
174a9efcbf4Smuneebkhan85   // The code below works only on values.
175a9efcbf4Smuneebkhan85   Location loc = op->getLoc();
176a9efcbf4Smuneebkhan85   ImplicitLocOpBuilder b(loc, builder);
177a9efcbf4Smuneebkhan85   if (emitAssertions) {
178a9efcbf4Smuneebkhan85     emitIsPositiveIndexAssertion(b, targetSize);
179a9efcbf4Smuneebkhan85   }
180a9efcbf4Smuneebkhan85   Value targetSizeValue =
181a9efcbf4Smuneebkhan85       getValueOrCreateConstantIndexOp(builder, loc, targetSize);
182a9efcbf4Smuneebkhan85 
183a9efcbf4Smuneebkhan85   // Find the trip count of the iteration space dimension for which the tile
184a9efcbf4Smuneebkhan85   // sizes are computed.
185a9efcbf4Smuneebkhan85   Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
186a9efcbf4Smuneebkhan85                                                     loopRanges[dimension].size);
187a9efcbf4Smuneebkhan85   ContinuousTileSizeSpecification spec;
188a9efcbf4Smuneebkhan85 
189a9efcbf4Smuneebkhan85   // Compute the tile sizes and the respective numbers of tiles.
190a9efcbf4Smuneebkhan85   AffineExpr s0 = b.getAffineSymbolExpr(0);
191a9efcbf4Smuneebkhan85   AffineExpr s1 = b.getAffineSymbolExpr(1);
192a9efcbf4Smuneebkhan85   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
193a9efcbf4Smuneebkhan85     return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
194a9efcbf4Smuneebkhan85   };
195a9efcbf4Smuneebkhan85 
196a9efcbf4Smuneebkhan85   Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
197a9efcbf4Smuneebkhan85   Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
198a9efcbf4Smuneebkhan85 
199a9efcbf4Smuneebkhan85   OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
200a9efcbf4Smuneebkhan85       b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
201a9efcbf4Smuneebkhan85 
202a9efcbf4Smuneebkhan85   // emitAssertions above already asserts that targetSize is
203a9efcbf4Smuneebkhan85   // a poistive integer.
204a9efcbf4Smuneebkhan85   uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
205a9efcbf4Smuneebkhan85 
206a9efcbf4Smuneebkhan85   assert(tileSizeInt > 0 && "target size must be non-negative");
207a9efcbf4Smuneebkhan85 
208a9efcbf4Smuneebkhan85   spec.tileSizes.push_back(targetSizeValue);
209a9efcbf4Smuneebkhan85   spec.tripCounts.push_back(tripCountValue);
210a9efcbf4Smuneebkhan85 
211a9efcbf4Smuneebkhan85   while (tileSizeInt > 1) {
212a9efcbf4Smuneebkhan85     uint64_t maxPower = llvm::bit_floor(tileSizeInt);
213a9efcbf4Smuneebkhan85     tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
214a9efcbf4Smuneebkhan85     auto constStepOp =
215a9efcbf4Smuneebkhan85         builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
216a9efcbf4Smuneebkhan85     tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
217a9efcbf4Smuneebkhan85 
218a9efcbf4Smuneebkhan85     tripCountSize = affine::makeComposedFoldedAffineApply(
219a9efcbf4Smuneebkhan85         b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
220a9efcbf4Smuneebkhan85 
221a9efcbf4Smuneebkhan85     // Optimization if tripCount can be determined to be zero.
222a9efcbf4Smuneebkhan85     if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
223a9efcbf4Smuneebkhan85       auto intAttr = cast<IntegerAttr>(attr);
224a9efcbf4Smuneebkhan85       bool isTripCountZero = intAttr.getValue().isZero();
225a9efcbf4Smuneebkhan85 
226a9efcbf4Smuneebkhan85       if (!isTripCountZero) {
227a9efcbf4Smuneebkhan85         spec.tileSizes.push_back(constStepOp);
228a9efcbf4Smuneebkhan85         spec.tripCounts.push_back(tripCountValue);
229a9efcbf4Smuneebkhan85       }
230a9efcbf4Smuneebkhan85     } else {
231a9efcbf4Smuneebkhan85       spec.tileSizes.push_back(constStepOp);
232a9efcbf4Smuneebkhan85       spec.tripCounts.push_back(tripCountValue);
233a9efcbf4Smuneebkhan85     }
234a9efcbf4Smuneebkhan85 
235a9efcbf4Smuneebkhan85     remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
236a9efcbf4Smuneebkhan85   }
237a9efcbf4Smuneebkhan85 
238a9efcbf4Smuneebkhan85   return spec;
239a9efcbf4Smuneebkhan85 }
240a9efcbf4Smuneebkhan85 
24188c5027bSAlex Zinenko FailureOr<StaticMultiSizeSpecification>
24288c5027bSAlex Zinenko mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
24388c5027bSAlex Zinenko                                           int64_t targetSize, int64_t divisor) {
24488c5027bSAlex Zinenko   assert(!op.hasDynamicShape() &&
24588c5027bSAlex Zinenko          "cannot compute static multi-tile sizes for an op with dynamic shape");
24688c5027bSAlex Zinenko   assert(targetSize > 0 && "target size must be non-negative");
24788c5027bSAlex Zinenko   assert(divisor > 0 && "divisor must be non-negative");
24888c5027bSAlex Zinenko   assert(dimension < op.getNumLoops() && "dimension overflow");
24988c5027bSAlex Zinenko 
25088c5027bSAlex Zinenko   StaticMultiSizeSpecification spec;
25188c5027bSAlex Zinenko   int64_t tripCount = op.getStaticLoopRanges()[dimension];
25288c5027bSAlex Zinenko   int64_t a = tripCount / divisor;
25388c5027bSAlex Zinenko   int64_t t = (targetSize + divisor - 1) / divisor;
25488c5027bSAlex Zinenko   int64_t totalTripCount = (a + t - 1) / t;
25588c5027bSAlex Zinenko   spec.lowTileSize = (a / totalTripCount) * divisor;
25688c5027bSAlex Zinenko   spec.highTileSize = spec.lowTileSize + divisor;
25788c5027bSAlex Zinenko   spec.highTripCount = a % totalTripCount;
25888c5027bSAlex Zinenko   spec.lowTripCount = totalTripCount - spec.highTripCount;
25988c5027bSAlex Zinenko   if (spec.lowTileSize * spec.lowTripCount +
26088c5027bSAlex Zinenko           spec.highTileSize * spec.highTripCount !=
26188c5027bSAlex Zinenko       tripCount) {
26288c5027bSAlex Zinenko     return failure();
26388c5027bSAlex Zinenko   }
26488c5027bSAlex Zinenko   return spec;
26588c5027bSAlex Zinenko }
26688c5027bSAlex Zinenko 
2673963b4d0SAlex Zinenko FailureOr<MultiSizeSpecification>
2683963b4d0SAlex Zinenko mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
2693963b4d0SAlex Zinenko                                     unsigned dimension, OpFoldResult targetSize,
2703963b4d0SAlex Zinenko                                     OpFoldResult divisor, bool emitAssertions) {
2713963b4d0SAlex Zinenko   // Bail out on dimension overflow.
2723963b4d0SAlex Zinenko   if (dimension >= op.getNumLoops())
2733963b4d0SAlex Zinenko     return failure();
2743963b4d0SAlex Zinenko 
2753963b4d0SAlex Zinenko   // The code below works only on values.
2764bf84e43SAlexander Belyaev   Location loc = op.getLoc();
2774bf84e43SAlexander Belyaev   ImplicitLocOpBuilder b(loc, builder);
2783963b4d0SAlex Zinenko   if (emitAssertions) {
2793963b4d0SAlex Zinenko     emitIsPositiveIndexAssertion(b, targetSize);
2803963b4d0SAlex Zinenko     emitIsPositiveIndexAssertion(b, divisor);
2813963b4d0SAlex Zinenko   }
2824bf84e43SAlexander Belyaev   Value targetSizeValue =
2834bf84e43SAlexander Belyaev       getValueOrCreateConstantIndexOp(builder, loc, targetSize);
2844bf84e43SAlexander Belyaev   Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);
2853963b4d0SAlex Zinenko 
2863963b4d0SAlex Zinenko   // Find the trip count of the iteration space dimension for which the tile
2873963b4d0SAlex Zinenko   // sizes are computed.
288e99fae89SAlex Zinenko   SmallVector<OpFoldResult> allShapes =
2893963b4d0SAlex Zinenko       op.createFlatListOfOperandDims(b, b.getLoc());
2903963b4d0SAlex Zinenko   AffineMap shapesToLoops = op.getShapesToLoopsMap();
291e99fae89SAlex Zinenko   SmallVector<OpFoldResult> loopRanges =
29226821f75SAlex Zinenko       makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
29326821f75SAlex Zinenko                                                allShapes);
294e99fae89SAlex Zinenko   Value tripCount =
2954bf84e43SAlexander Belyaev       getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
2963963b4d0SAlex Zinenko 
2973963b4d0SAlex Zinenko   // Compute the tile sizes and the respective numbers of tiles.
2983963b4d0SAlex Zinenko   AffineExpr s0 = b.getAffineSymbolExpr(0);
2993963b4d0SAlex Zinenko   AffineExpr s1 = b.getAffineSymbolExpr(1);
3003963b4d0SAlex Zinenko   AffineExpr s2 = b.getAffineSymbolExpr(2);
301efc290ceSMatthias Springer   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
302efc290ceSMatthias Springer     return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
3033963b4d0SAlex Zinenko   };
3043963b4d0SAlex Zinenko   Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
3053963b4d0SAlex Zinenko   Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
3063963b4d0SAlex Zinenko   Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
3073963b4d0SAlex Zinenko   Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
3083963b4d0SAlex Zinenko   Value v = apply(s0 % s1, {a, d});
3093963b4d0SAlex Zinenko   Value u = apply(s0 - s1, {d, v});
3103963b4d0SAlex Zinenko 
3113963b4d0SAlex Zinenko   MultiSizeSpecification spec;
3123963b4d0SAlex Zinenko   spec.lowTileSize = s;
3133963b4d0SAlex Zinenko   spec.highTileSize = apply(s0 + s1, {s, divisorValue});
3143963b4d0SAlex Zinenko   spec.lowTripCount = u;
3153963b4d0SAlex Zinenko   spec.highTripCount = v;
3163963b4d0SAlex Zinenko 
3173963b4d0SAlex Zinenko   // If requested, emit the check that the tile sizes are computed correctly.
3183963b4d0SAlex Zinenko   // For example, for iteration dimension size of 15 and the target size 8 it is
3193963b4d0SAlex Zinenko   // impossible to find two tile sizes both divisible by 8 that fully cover the
3203963b4d0SAlex Zinenko   // original space dimension.
3213963b4d0SAlex Zinenko   if (emitAssertions) {
3223963b4d0SAlex Zinenko     AffineExpr s3 = builder.getAffineSymbolExpr(3);
3233963b4d0SAlex Zinenko     Value coveredSize =
3243963b4d0SAlex Zinenko         apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
3253963b4d0SAlex Zinenko                                   spec.highTileSize, spec.highTripCount});
3263963b4d0SAlex Zinenko     Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
3273963b4d0SAlex Zinenko                                            coveredSize, tripCount);
3283963b4d0SAlex Zinenko     b.create<cf::AssertOp>(
3293963b4d0SAlex Zinenko         equals, builder.getStringAttr(
3303963b4d0SAlex Zinenko                     "could not compute dynamic multi-size tile shapes"));
3313963b4d0SAlex Zinenko   }
3323963b4d0SAlex Zinenko 
3333963b4d0SAlex Zinenko   return spec;
3343963b4d0SAlex Zinenko }
3353963b4d0SAlex Zinenko 
336297ba167SChristopher Bate /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
337297ba167SChristopher Bate /// than `iterationSize`.
338297ba167SChristopher Bate static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
339297ba167SChristopher Bate                                            OpFoldResult numThreads,
340297ba167SChristopher Bate                                            OpFoldResult iterationSize) {
34122426110SRamkumar Ramachandra   std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
34222426110SRamkumar Ramachandra   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
34322426110SRamkumar Ramachandra   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
344297ba167SChristopher Bate   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
345297ba167SChristopher Bate     return false;
346297ba167SChristopher Bate   return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
34718b92c66SNicolas Vasilache }
34818b92c66SNicolas Vasilache 
349e99fae89SAlex Zinenko /// Build an `affine_max` of all the `vals`.
350e99fae89SAlex Zinenko static OpFoldResult buildMax(OpBuilder &b, Location loc,
351e99fae89SAlex Zinenko                              ArrayRef<OpFoldResult> vals) {
3524c48f016SMatthias Springer   return affine::makeComposedFoldedAffineMax(
35326821f75SAlex Zinenko       b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
35426821f75SAlex Zinenko       vals);
355e99fae89SAlex Zinenko }
356e99fae89SAlex Zinenko 
357e99fae89SAlex Zinenko /// Build an `affine_min` of all the `vals`.
358e99fae89SAlex Zinenko static OpFoldResult buildMin(OpBuilder &b, Location loc,
359e99fae89SAlex Zinenko                              ArrayRef<OpFoldResult> vals) {
3604c48f016SMatthias Springer   return affine::makeComposedFoldedAffineMin(
36126821f75SAlex Zinenko       b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
36226821f75SAlex Zinenko       vals);
363e99fae89SAlex Zinenko }
364e99fae89SAlex Zinenko 
36599833cd8SThomas Raoux /// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
36699833cd8SThomas Raoux /// number of threads.
36799833cd8SThomas Raoux static void calculateTileOffsetsAndSizes(
368eb2f946eSAlexander Belyaev     RewriterBase &b, Location loc, scf::ForallOp forallOp,
36999833cd8SThomas Raoux     ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
37099833cd8SThomas Raoux     bool omitTileOffsetBoundsCheck,
37122426110SRamkumar Ramachandra     std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
37299833cd8SThomas Raoux     SmallVector<OpFoldResult> &tiledOffsets,
37399833cd8SThomas Raoux     SmallVector<OpFoldResult> &tiledSizes) {
37406ca5c81SNicolas Vasilache   OpBuilder::InsertionGuard g(b);
375eb2f946eSAlexander Belyaev   b.setInsertionPointToStart(forallOp.getBody(0));
37606ca5c81SNicolas Vasilache 
3776b4c1228Ssrcarroll   SmallVector<Value> threadIds = forallOp.getInductionVars();
378f4d75863SJakub Kuderski   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
379f4d75863SJakub Kuderski       numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
38018b92c66SNicolas Vasilache   int64_t nLoops = loopRanges.size();
38118b92c66SNicolas Vasilache   tiledOffsets.reserve(nLoops);
38218b92c66SNicolas Vasilache   tiledSizes.reserve(nLoops);
383297ba167SChristopher Bate   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
38418b92c66SNicolas Vasilache     bool overflow = loopIdx >= numThreads.size();
38518b92c66SNicolas Vasilache     bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
38618b92c66SNicolas Vasilache     // Degenerate case: take the whole domain.
38718b92c66SNicolas Vasilache     if (overflow || isZero) {
38818b92c66SNicolas Vasilache       tiledOffsets.push_back(loopRanges[loopIdx].offset);
38918b92c66SNicolas Vasilache       tiledSizes.push_back(loopRanges[loopIdx].size);
39018b92c66SNicolas Vasilache       continue;
39118b92c66SNicolas Vasilache     }
39218b92c66SNicolas Vasilache 
39318b92c66SNicolas Vasilache     // Tiled case: compute the offset and size.
394a7a892aeSMehdi Amini     AffineExpr i, j, m, n, o;
39518b92c66SNicolas Vasilache     bindDims(b.getContext(), i, j);
396a7a892aeSMehdi Amini     bindSymbols(b.getContext(), m, n, o);
397e99fae89SAlex Zinenko     OpFoldResult size = loopRanges[loopIdx].size;
398e99fae89SAlex Zinenko     OpFoldResult offset = loopRanges[loopIdx].offset;
399e99fae89SAlex Zinenko     OpFoldResult threadId = threadIds[threadIdIdx];
40018b92c66SNicolas Vasilache     // Symbolic fixed max size per thread.
40118b92c66SNicolas Vasilache     // TODO: floor + 0/1 depending on case for better load-balancing.
402297ba167SChristopher Bate     OpFoldResult tileSizePerThread =
4036fa6901bSKazu Hirata         nominalTileSizes.has_value()
404297ba167SChristopher Bate             ? (*nominalTileSizes)[loopIdx]
405297ba167SChristopher Bate             : makeComposedFoldedAffineApply(
406a7a892aeSMehdi Amini                   b, loc, m.ceilDiv(n),
407297ba167SChristopher Bate                   ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
408297ba167SChristopher Bate 
40918b92c66SNicolas Vasilache     // Dynamic offset shifted by threadId * maxSizePerThread.
410297ba167SChristopher Bate     OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
411a7a892aeSMehdi Amini         b, loc, i + j * m, {offset, threadId, tileSizePerThread});
41218b92c66SNicolas Vasilache     // Dynamic upper-bound depending on the threadId.
413297ba167SChristopher Bate     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
414a7a892aeSMehdi Amini         b, loc, i + j * m - n,
415297ba167SChristopher Bate         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
416297ba167SChristopher Bate     if (!isConstantIntValue(residualTileSize, 0)) {
417297ba167SChristopher Bate       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
418a7a892aeSMehdi Amini           b, loc, -i + m, {offsetPerThread, size});
419e99fae89SAlex Zinenko       tileSizePerThread =
420e99fae89SAlex Zinenko           buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
421297ba167SChristopher Bate     }
422297ba167SChristopher Bate 
42318b92c66SNicolas Vasilache     tiledOffsets.push_back(offsetPerThread);
42418b92c66SNicolas Vasilache     // TODO: if tileSizePerThread <= 0 early exit.
425297ba167SChristopher Bate     if (!omitTileOffsetBoundsCheck &&
426297ba167SChristopher Bate         !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
427297ba167SChristopher Bate                                         nonZeroNumThreads[threadIdIdx], size))
428e99fae89SAlex Zinenko       tileSizePerThread =
429e99fae89SAlex Zinenko           buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});
430297ba167SChristopher Bate 
431297ba167SChristopher Bate     tiledSizes.push_back(tileSizePerThread);
43218b92c66SNicolas Vasilache     ++threadIdIdx;
43318b92c66SNicolas Vasilache   }
43499833cd8SThomas Raoux }
43599833cd8SThomas Raoux 
4360da755dfSAlexander Belyaev template <typename LoopTy>
437489fec27SNicolas Vasilache static FailureOr<TiledLinalgOp>
438e99fae89SAlex Zinenko tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
439c694588fSMaheshRavishankar                  const LinalgTilingOptions &options) {
44006ca5c81SNicolas Vasilache   OpBuilder::InsertionGuard g(b);
44106ca5c81SNicolas Vasilache 
442004a3d4fSNicolas Vasilache   auto nLoops = op.getNumLoops();
443004a3d4fSNicolas Vasilache   // Initial tile sizes may be too big, only take the first nLoops.
444004a3d4fSNicolas Vasilache   tileSizes = tileSizes.take_front(nLoops);
445004a3d4fSNicolas Vasilache 
446cb7bda2aSMatthias Springer   if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
447cb7bda2aSMatthias Springer         return getConstantIntValue(ofr) == static_cast<int64_t>(0);
448cb7bda2aSMatthias Springer       })) {
449526dfe3fSMaheshRavishankar     TiledLinalgOp tiledOp;
450526dfe3fSMaheshRavishankar     tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
451526dfe3fSMaheshRavishankar     tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
452526dfe3fSMaheshRavishankar                                  tiledOp.op->result_end());
453526dfe3fSMaheshRavishankar     return tiledOp;
454526dfe3fSMaheshRavishankar   }
455f60bbb6cSJose Ignacio Gomez 
456c694588fSMaheshRavishankar   // 1. Build the tiled loop ranges.
457e99fae89SAlex Zinenko   SmallVector<OpFoldResult> allShapeSizes =
458e99fae89SAlex Zinenko       op.createFlatListOfOperandDims(b, op.getLoc());
45901c44185SNicolas Vasilache   AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
460a3adcba6SNicolas Vasilache   if (!shapeSizesToLoopsMap)
461489fec27SNicolas Vasilache     return failure();
462bae8a7a7SAlexander Belyaev 
4639fa59e76SBenjamin Kramer   auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
464a3adcba6SNicolas Vasilache       b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
46542444d0cSMaheshRavishankar 
466e6598b05SOleg Shyshkov   SmallVector<utils::IteratorType, 4> iteratorTypes;
467c54bc8bdSOleg Shyshkov   for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
468c694588fSMaheshRavishankar     if (loopIndexToRangeIndex.count(attr.index()))
469c694588fSMaheshRavishankar       iteratorTypes.push_back(attr.value());
470c694588fSMaheshRavishankar   }
471c694588fSMaheshRavishankar   // If interchangeVector is empty, use the identity. Build the permutation map
472c694588fSMaheshRavishankar   // otherwise.
473c694588fSMaheshRavishankar   auto invPermutationMap =
474c694588fSMaheshRavishankar       AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
475c694588fSMaheshRavishankar   if (!options.interchangeVector.empty()) {
476c694588fSMaheshRavishankar     // Based on the pruned iterations (due to zero tile size), recompute the
477c694588fSMaheshRavishankar     // interchange vector.
478c694588fSMaheshRavishankar     SmallVector<unsigned, 4> interchangeVector;
479c694588fSMaheshRavishankar     interchangeVector.reserve(options.interchangeVector.size());
480c694588fSMaheshRavishankar     for (auto pos : options.interchangeVector) {
481c694588fSMaheshRavishankar       auto it = loopIndexToRangeIndex.find(pos);
482c694588fSMaheshRavishankar       if (it == loopIndexToRangeIndex.end())
483c694588fSMaheshRavishankar         continue;
484c694588fSMaheshRavishankar       interchangeVector.push_back(it->second);
485c694588fSMaheshRavishankar     }
48601c44185SNicolas Vasilache     // Interchange vector is guaranteed to be a permutation,
48701c44185SNicolas Vasilache     // `inversePermutation` must succeed.
488c694588fSMaheshRavishankar     invPermutationMap = inversePermutation(
489c694588fSMaheshRavishankar         AffineMap::getPermutationMap(interchangeVector, b.getContext()));
49001c44185SNicolas Vasilache     assert(invPermutationMap);
4919072f1b5STobias Gysi     SmallVector<int64_t> permutation(interchangeVector.begin(),
4929072f1b5STobias Gysi                                      interchangeVector.end());
4939072f1b5STobias Gysi     applyPermutationToVector(loopRanges, permutation);
4949072f1b5STobias Gysi     applyPermutationToVector(iteratorTypes, permutation);
495c694588fSMaheshRavishankar   }
496b6281940SNicolas Vasilache 
497f365e85cSMahesh Ravishankar   // Handle distribution. Create a vector of the same size of loops that are to
498f365e85cSMahesh Ravishankar   // be tiled.
499f365e85cSMahesh Ravishankar   SmallVector<linalg::ProcInfo> procInfo;
500f365e85cSMahesh Ravishankar   if (options.distribution) {
501f365e85cSMahesh Ravishankar     procInfo.resize(
502f365e85cSMahesh Ravishankar         iteratorTypes.size(),
503f365e85cSMahesh Ravishankar         linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
50423bd2e96SMatthias Springer     // Collect loop ranges of tiled loops, loops that are parallel.
505f365e85cSMahesh Ravishankar     SmallVector<Range> parallelLoopRanges;
506e0568fa7SAdrian Kuegel     for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
507f365e85cSMahesh Ravishankar       if (!isParallelIterator(iteratorType.value()))
508f365e85cSMahesh Ravishankar         break;
509f365e85cSMahesh Ravishankar       parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
510f365e85cSMahesh Ravishankar     }
511f365e85cSMahesh Ravishankar     auto returnedProcInfo =
512f365e85cSMahesh Ravishankar         options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
513f365e85cSMahesh Ravishankar     unsigned procIdIdx = 0;
514f365e85cSMahesh Ravishankar     // Update the distribution information for the loops.
515e0568fa7SAdrian Kuegel     for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
516f365e85cSMahesh Ravishankar       if (!isParallelIterator(iteratorType.value()))
517f365e85cSMahesh Ravishankar         break;
518f365e85cSMahesh Ravishankar       procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
519f365e85cSMahesh Ravishankar     }
520f365e85cSMahesh Ravishankar   }
521f365e85cSMahesh Ravishankar 
522c694588fSMaheshRavishankar   // 2. Create the tiled loops.
52358ddeba3SAlexander Belyaev   LinalgOp res = op;
524a3adcba6SNicolas Vasilache   SmallVector<Value, 4> ivs, tensorResults;
52516488dc3STobias Gysi   auto tiledLoopBodyBuilder =
5264a661602SNicolas Vasilache       [&](OpBuilder &builder, Location loc, ValueRange localIvs,
52716488dc3STobias Gysi           ValueRange operandValuesToUse) -> scf::ValueVector {
528b4bc72afSAlex Zinenko     ivs.assign(localIvs.begin(), localIvs.end());
529f60bbb6cSJose Ignacio Gomez 
530a3adcba6SNicolas Vasilache     // When an `interchangeVector` is present, it has been applied to the
531a3adcba6SNicolas Vasilache     // loop ranges and the iterator types. Apply its inverse to the
532a3adcba6SNicolas Vasilache     // resulting loop `ivs` to match the op definition.
533a3adcba6SNicolas Vasilache     SmallVector<Value, 4> interchangedIvs;
53423bd2e96SMatthias Springer     if (!options.interchangeVector.empty()) {
53523bd2e96SMatthias Springer       for (AffineExpr result : invPermutationMap.getResults())
53623bd2e96SMatthias Springer         interchangedIvs.push_back(
5371609f1c2Slong.chen             ivs[cast<AffineDimExpr>(result).getPosition()]);
53823bd2e96SMatthias Springer     } else {
539a3adcba6SNicolas Vasilache       interchangedIvs.assign(ivs.begin(), ivs.end());
54023bd2e96SMatthias Springer     }
541f60bbb6cSJose Ignacio Gomez 
54216488dc3STobias Gysi     // Tile the `operandValuesToUse` that either match the `op` operands
54316488dc3STobias Gysi     // themselves or the tile loop arguments forwarding them.
54416488dc3STobias Gysi     assert(operandValuesToUse.size() ==
545a7cccb9cSAlexander Belyaev                static_cast<size_t>(op->getNumOperands()) &&
54616488dc3STobias Gysi            "expect the number of operands and inputs and outputs to match");
54716488dc3STobias Gysi     SmallVector<Value> valuesToTile = operandValuesToUse;
548e99fae89SAlex Zinenko     SmallVector<OpFoldResult> sizeBounds =
54926821f75SAlex Zinenko         makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
55026821f75SAlex Zinenko                                                  allShapeSizes);
551e99fae89SAlex Zinenko     SmallVector<Value> tiledOperands = makeTiledShapes(
552e99fae89SAlex Zinenko         b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
553e99fae89SAlex Zinenko         sizeBounds,
554e99fae89SAlex Zinenko         /*omitPartialTileCheck=*/false);
555a3adcba6SNicolas Vasilache 
556ff6e5508SAlex Zinenko     SmallVector<Type> resultTensorTypes =
557ff6e5508SAlex Zinenko         getTensorOutputTypes(op, tiledOperands);
558f286af29SAlexander Belyaev     res = clone(b, op, resultTensorTypes, tiledOperands);
559ff6e5508SAlex Zinenko     tensorResults =
560ff6e5508SAlex Zinenko         insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
56158ddeba3SAlexander Belyaev     return scf::ValueVector(tensorResults.begin(), tensorResults.end());
56284a880e1SNicolas Vasilache   };
56384a880e1SNicolas Vasilache   GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
564f365e85cSMahesh Ravishankar                                  tiledLoopBodyBuilder, procInfo);
565b6281940SNicolas Vasilache 
566d69bccf1STobias Gysi   // 3. Transform IndexOp results w.r.t. the tiling.
56758ddeba3SAlexander Belyaev   transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
568bae8a7a7SAlexander Belyaev 
569c694588fSMaheshRavishankar   // 4. Gather the newly created loops and return them with the new op.
5700da755dfSAlexander Belyaev   SmallVector<Operation *, 8> loops;
571b6281940SNicolas Vasilache   loops.reserve(ivs.size());
572a5bfd32cSLei Zhang   for (auto iv : ivs) {
5735550c821STres Popp     if (isa<BlockArgument>(iv)) {
5745550c821STres Popp       loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
575a5bfd32cSLei Zhang       assert(loops.back() && "no owner found for induction variable!");
57641d41200SMaheshRavishankar     } else {
57741d41200SMaheshRavishankar       // TODO: Instead of doing this, try to recover the ops used instead of the
57841d41200SMaheshRavishankar       // loop.
57941d41200SMaheshRavishankar       loops.push_back(nullptr);
58041d41200SMaheshRavishankar     }
581a5bfd32cSLei Zhang   }
582a3adcba6SNicolas Vasilache 
583a3adcba6SNicolas Vasilache   // 5. Get the tensor results from the outermost loop if available. Otherwise
584a3adcba6SNicolas Vasilache   // use the previously captured `tensorResults`.
585a3adcba6SNicolas Vasilache   Operation *outermostLoop = nullptr;
586a3adcba6SNicolas Vasilache   for (Operation *loop : loops)
587a3adcba6SNicolas Vasilache     if ((outermostLoop = loop))
588a3adcba6SNicolas Vasilache       break;
589a3adcba6SNicolas Vasilache 
59058ddeba3SAlexander Belyaev   return TiledLinalgOp{
59158ddeba3SAlexander Belyaev       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
592b6281940SNicolas Vasilache }
593b6281940SNicolas Vasilache 
594eb2f946eSAlexander Belyaev FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
595eb2f946eSAlexander Belyaev     RewriterBase &b, PartialReductionOpInterface op,
596eb2f946eSAlexander Belyaev     ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
59722426110SRamkumar Ramachandra     std::optional<ArrayAttr> mapping) {
598f7fda6baSThomas Raoux   Location loc = op.getLoc();
599f7fda6baSThomas Raoux   OpBuilder::InsertionGuard g(b);
60006ca5c81SNicolas Vasilache 
601f7fda6baSThomas Raoux   // Ops implementing PartialReductionOpInterface are expected to implement
602f7fda6baSThomas Raoux   // TilingInterface.
60306ca5c81SNicolas Vasilache   // TODO: proper core mechanism to tie interfaces together.
604f7fda6baSThomas Raoux   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
60506ca5c81SNicolas Vasilache 
60606ca5c81SNicolas Vasilache   // Ops implementing PartialReductionOpInterface are not necessarily expected
60706ca5c81SNicolas Vasilache   // to implement TilingInterface.. This cast is unsafe atm.
60806ca5c81SNicolas Vasilache   // TODO: proper core mechanism to tie interfaces together.
60906ca5c81SNicolas Vasilache   // TODO: this function requires a pair of interfaces ..
61006ca5c81SNicolas Vasilache   auto destinationStyleOp =
61106ca5c81SNicolas Vasilache       dyn_cast<DestinationStyleOpInterface>(op.getOperation());
61206ca5c81SNicolas Vasilache   if (!destinationStyleOp)
61306ca5c81SNicolas Vasilache     return b.notifyMatchFailure(op, "not a destination style op");
61406ca5c81SNicolas Vasilache 
61506ca5c81SNicolas Vasilache   // Actually this only work for Linalg ops atm.
61606ca5c81SNicolas Vasilache   auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
61706ca5c81SNicolas Vasilache   if (!linalgOp)
61806ca5c81SNicolas Vasilache     return b.notifyMatchFailure(op, "not a linalg op");
61906ca5c81SNicolas Vasilache 
620f7fda6baSThomas Raoux   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
621f7fda6baSThomas Raoux   if (op->getNumResults() != 1)
622f7fda6baSThomas Raoux     return b.notifyMatchFailure(
623f7fda6baSThomas Raoux         op, "don't support ops with multiple results for now");
62406ca5c81SNicolas Vasilache 
625f7fda6baSThomas Raoux   SmallVector<utils::IteratorType> iterators =
626f7fda6baSThomas Raoux       tilingInterfaceOp.getLoopIteratorTypes();
627f7fda6baSThomas Raoux   SmallVector<unsigned> redDims;
62806ca5c81SNicolas Vasilache   linalgOp.getReductionDims(redDims);
629f7fda6baSThomas Raoux   if (redDims.size() != 1)
630f7fda6baSThomas Raoux     return b.notifyMatchFailure(
631f7fda6baSThomas Raoux         op, "only support ops with one reduction dimension.");
632f7fda6baSThomas Raoux   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
633f7fda6baSThomas Raoux     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
634f7fda6baSThomas Raoux                                     "many elements as number of threads");
635f7fda6baSThomas Raoux   int reductionDim = static_cast<int>(redDims.front());
63606ca5c81SNicolas Vasilache 
637faac8989SAlex Zinenko   if (redDims.front() >= numThreads.size())
638faac8989SAlex Zinenko     return b.notifyMatchFailure(
639faac8989SAlex Zinenko         op, "reduction dimension must be mapped to threads");
640faac8989SAlex Zinenko 
64106ca5c81SNicolas Vasilache   // 1. Create the inital tensor value.
6429329b20dSKunwar Grover   FailureOr<SmallVector<Value>> maybeInitTensors =
643f7fda6baSThomas Raoux       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
644f7fda6baSThomas Raoux                                                   reductionDim);
6459329b20dSKunwar Grover   if (failed(maybeInitTensors))
6469329b20dSKunwar Grover     return b.notifyMatchFailure(
6479329b20dSKunwar Grover         op, "Failed to create inital tensors for partial reduction");
6489329b20dSKunwar Grover   SmallVector<Value> &initTensors = maybeInitTensors.value();
649f7fda6baSThomas Raoux 
650f7fda6baSThomas Raoux   // Gather destination tensors.
651f7fda6baSThomas Raoux   SmallVector<Value> dest;
652f7fda6baSThomas Raoux   if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
653f7fda6baSThomas Raoux     return b.notifyMatchFailure(op, "failed to get destination tensors");
654f7fda6baSThomas Raoux 
655f7fda6baSThomas Raoux   Operation *tiledOp = nullptr;
656f7fda6baSThomas Raoux 
657f4d75863SJakub Kuderski   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658f4d75863SJakub Kuderski       numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659f7fda6baSThomas Raoux   SmallVector<Value> materializedNonZeroNumThreads =
660c888a0ceSNicolas Vasilache       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661f7fda6baSThomas Raoux 
662eb2f946eSAlexander Belyaev   // 2. Create the ForallOp with an empty region.
663eb2f946eSAlexander Belyaev   scf::ForallOp forallOp = b.create<scf::ForallOp>(
6649329b20dSKunwar Grover       loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
6659329b20dSKunwar Grover       mapping);
666f7fda6baSThomas Raoux 
66706ca5c81SNicolas Vasilache   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
668eb2f946eSAlexander Belyaev   // be nested under `forallOp`.
669f7fda6baSThomas Raoux   SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
670eb2f946eSAlexander Belyaev   calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
671f7fda6baSThomas Raoux                                /*omitTileOffsetBoundsCheck =*/false,
672eb2f946eSAlexander Belyaev                                /*nominalTileSizes=*/std::nullopt, tiledOffsets,
673eb2f946eSAlexander Belyaev                                tiledSizes);
674f7fda6baSThomas Raoux 
6759329b20dSKunwar Grover   // 4b. Clone the tileable op and update its destination operands to use the
676eb2f946eSAlexander Belyaev   // output bbArgs of the ForallOp.
6771da04b04SMahesh Ravishankar   SmallVector<Value> tilingResults;
67876ead96cSMaheshRavishankar   ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
67906ca5c81SNicolas Vasilache   {
680eb2f946eSAlexander Belyaev     // 4.a. RAII guard, inserting within forallOp, before terminator.
68106ca5c81SNicolas Vasilache     OpBuilder::InsertionGuard g(b);
682eb2f946eSAlexander Belyaev     b.setInsertionPoint(forallOp.getTerminator());
68306ca5c81SNicolas Vasilache 
68406ca5c81SNicolas Vasilache     SmallVector<Value> tiledDpsInitOperands;
6850b2197b0SMatthias Springer     for (Value initOperand : destinationStyleOp.getDpsInits()) {
6860b2197b0SMatthias Springer       auto *it = llvm::find(dest, initOperand);
687f7fda6baSThomas Raoux       assert(it != dest.end() && "dest operand not found in dest");
688f7fda6baSThomas Raoux       unsigned destNum = std::distance(dest.begin(), it);
689f7fda6baSThomas Raoux       SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
69006ca5c81SNicolas Vasilache       SmallVector<OpFoldResult> outOffsets(numThreads.size(),
69106ca5c81SNicolas Vasilache                                            b.getIndexAttr(0));
692f7fda6baSThomas Raoux       SmallVector<OpFoldResult> sizes = tiledSizes;
693f7fda6baSThomas Raoux       sizes[reductionDim] = b.getIndexAttr(1);
6946b4c1228Ssrcarroll       outOffsets[reductionDim] = forallOp.getInductionVars()[0];
695f7fda6baSThomas Raoux       // TODO: use SubsetExtractOpInterface once it is available.
69606ca5c81SNicolas Vasilache       tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
6970b2197b0SMatthias Springer           loc, cast<RankedTensorType>(initOperand.getType()),
69806ca5c81SNicolas Vasilache           destBbArgs[destNum], outOffsets, sizes, strides));
699f7fda6baSThomas Raoux     }
70006ca5c81SNicolas Vasilache 
70106ca5c81SNicolas Vasilache     // 4.b. Clone the op and update init operands.
7024d67b278SJeff Niu     // We cannot use a IRMapping here because it can replace
70306ca5c81SNicolas Vasilache     // different OpOperands with the same value.
70406ca5c81SNicolas Vasilache     Operation *clonedOp = b.clone(*op.getOperation());
7055fcf907bSMatthias Springer     b.modifyOpInPlace(clonedOp, [&]() {
70606ca5c81SNicolas Vasilache       for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
7070b2197b0SMatthias Springer                cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
70806ca5c81SNicolas Vasilache                tiledDpsInitOperands)) {
7090b2197b0SMatthias Springer         initOperandPtr.set(tiledInitValue);
71006ca5c81SNicolas Vasilache       }
71106ca5c81SNicolas Vasilache     });
712f7fda6baSThomas Raoux 
713f7fda6baSThomas Raoux     // 5. Tile the cloned op and delete the clone.
714f7fda6baSThomas Raoux     if (tileSizes.empty()) {
715809e3d8cSMahesh Ravishankar       FailureOr<TilingResult> tilingResult =
71606ca5c81SNicolas Vasilache           cast<TilingInterface>(clonedOp).getTiledImplementation(
71706ca5c81SNicolas Vasilache               b, tiledOffsets, tiledSizes);
718242b3091SNicolas Vasilache       if (failed(tilingResult))
719242b3091SNicolas Vasilache         return clonedOp->emitError("Failed to tile op: ");
720242b3091SNicolas Vasilache       if (tilingResult->tiledOps.size() != 1) {
721242b3091SNicolas Vasilache         return clonedOp->emitError("expected a single produced tiled op, got ")
722242b3091SNicolas Vasilache                << tilingResult->tiledOps.size();
723242b3091SNicolas Vasilache       }
724809e3d8cSMahesh Ravishankar       tiledOp = tilingResult->tiledOps.front();
725809e3d8cSMahesh Ravishankar       tilingResults = tilingResult->tiledValues;
726f7fda6baSThomas Raoux     } else {
727f7fda6baSThomas Raoux       LinalgTilingOptions options;
72806ca5c81SNicolas Vasilache       FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
72906ca5c81SNicolas Vasilache           b, cast<LinalgOp>(clonedOp), tileSizes, options);
73006ca5c81SNicolas Vasilache       if (failed(maybeTiled))
73106ca5c81SNicolas Vasilache         return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
73206ca5c81SNicolas Vasilache 
733eb2f946eSAlexander Belyaev       SmallVector<Value> ids = forallOp.getInductionVars();
73406ca5c81SNicolas Vasilache       mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
735f7fda6baSThomas Raoux                             materializedNonZeroNumThreads);
736242b3091SNicolas Vasilache       if (maybeTiled->loops.size() != 1) {
737242b3091SNicolas Vasilache         return clonedOp->emitError("expected a single produced loop");
738242b3091SNicolas Vasilache       }
73906ca5c81SNicolas Vasilache       tiledOp = maybeTiled->op;
74006ca5c81SNicolas Vasilache       tilingResults = maybeTiled->loops.front()->getResults();
741f7fda6baSThomas Raoux     }
74206ca5c81SNicolas Vasilache 
743f7fda6baSThomas Raoux     b.eraseOp(clonedOp);
74406ca5c81SNicolas Vasilache   }
745f7fda6baSThomas Raoux 
746f7fda6baSThomas Raoux   // 6. Insert the partial reductions back into a new tensor.
74706ca5c81SNicolas Vasilache   for (auto [index, result, bbArg] : llvm::zip(
74806ca5c81SNicolas Vasilache            llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
74906ca5c81SNicolas Vasilache     // 6.a. Partial subset information is inserted just before the terminator.
75006ca5c81SNicolas Vasilache     OpBuilder::InsertionGuard g(b);
751eb2f946eSAlexander Belyaev     b.setInsertionPoint(forallOp.getTerminator());
75206ca5c81SNicolas Vasilache 
753f7fda6baSThomas Raoux     SmallVector<OpFoldResult> resultOffsets, resultSizes;
754f7fda6baSThomas Raoux     if (failed(tilingInterfaceOp.getResultTilePosition(
755f7fda6baSThomas Raoux             b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
756f7fda6baSThomas Raoux       return op->emitOpError("output offsets couldn't be calculated");
757f7fda6baSThomas Raoux     SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
758f7fda6baSThomas Raoux     int64_t offIdx = 0;
759f7fda6baSThomas Raoux     int64_t sizeIdx = 0;
760f7fda6baSThomas Raoux     for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
761f7fda6baSThomas Raoux       if (i == reductionDim) {
7626b4c1228Ssrcarroll         resultOffsetsRank.push_back(forallOp.getInductionVars()[0]);
763f7fda6baSThomas Raoux         resultSizesRank.push_back(b.getIndexAttr(1));
764f7fda6baSThomas Raoux         continue;
765f7fda6baSThomas Raoux       }
766f7fda6baSThomas Raoux       resultOffsetsRank.push_back(resultOffsets[offIdx++]);
767f7fda6baSThomas Raoux       resultSizesRank.push_back(resultSizes[sizeIdx++]);
768f7fda6baSThomas Raoux     }
769f7fda6baSThomas Raoux     SmallVector<OpFoldResult> strides(resultSizesRank.size(),
770f7fda6baSThomas Raoux                                       b.getIndexAttr(1));
77106ca5c81SNicolas Vasilache 
77206ca5c81SNicolas Vasilache     // 6.b. Parallel insertions are inserted at the end of the combining
77306ca5c81SNicolas Vasilache     // terminator.
774eb2f946eSAlexander Belyaev     b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
775f7fda6baSThomas Raoux     b.create<tensor::ParallelInsertSliceOp>(
776f7fda6baSThomas Raoux         loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
777f7fda6baSThomas Raoux   }
77806ca5c81SNicolas Vasilache 
779f7fda6baSThomas Raoux   // 7. Merge the partial reductions.
780eb2f946eSAlexander Belyaev   b.setInsertionPointAfter(forallOp);
781b99d0b34SMaheshRavishankar   FailureOr<MergeResult> mergeResult =
782eb2f946eSAlexander Belyaev       op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
783b99d0b34SMaheshRavishankar   if (failed(mergeResult)) {
784b99d0b34SMaheshRavishankar     return failure();
785b99d0b34SMaheshRavishankar   }
786b99d0b34SMaheshRavishankar   b.replaceOp(op, mergeResult->replacements);
78706ca5c81SNicolas Vasilache 
78806ca5c81SNicolas Vasilache   // 8. Return.
789eb2f946eSAlexander Belyaev   ForallReductionTilingResult results;
7909329b20dSKunwar Grover   results.initialValues = initTensors;
791eb2f946eSAlexander Belyaev   results.loops = forallOp;
792b99d0b34SMaheshRavishankar   results.parallelTiledOps.push_back(tiledOp);
793b99d0b34SMaheshRavishankar   results.mergeOps.append(mergeResult->mergeOps);
794f7fda6baSThomas Raoux   return results;
795f7fda6baSThomas Raoux }
796f7fda6baSThomas Raoux 
797c694588fSMaheshRavishankar template <typename LoopTy>
798489fec27SNicolas Vasilache FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
7994a661602SNicolas Vasilache     RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
800c694588fSMaheshRavishankar   OpBuilder::InsertionGuard g(b);
801c694588fSMaheshRavishankar   b.setInsertionPoint(op);
802c694588fSMaheshRavishankar 
8034643fd27SNicolas Vasilache   if (!options.tileSizeComputationFunction)
804489fec27SNicolas Vasilache     return failure();
8054643fd27SNicolas Vasilache 
806c694588fSMaheshRavishankar   // Enforce the convention that "tiling by zero" skips tiling a particular
807c694588fSMaheshRavishankar   // dimension. This convention is significantly simpler to handle instead of
808c694588fSMaheshRavishankar   // adjusting affine maps to account for missing dimensions.
809c694588fSMaheshRavishankar   auto nLoops = op.getNumLoops();
810e99fae89SAlex Zinenko   SmallVector<OpFoldResult> tileSizeVector =
811e99fae89SAlex Zinenko       getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
812c694588fSMaheshRavishankar   if (tileSizeVector.size() < nLoops) {
813e99fae89SAlex Zinenko     tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
814c694588fSMaheshRavishankar   }
815c694588fSMaheshRavishankar 
816c694588fSMaheshRavishankar   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
817c694588fSMaheshRavishankar }
818c694588fSMaheshRavishankar 
819489fec27SNicolas Vasilache FailureOr<TiledLinalgOp>
8204a661602SNicolas Vasilache mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
821004a3d4fSNicolas Vasilache                            const LinalgTilingOptions &options) {
822c694588fSMaheshRavishankar   switch (options.loopType) {
823c694588fSMaheshRavishankar   case LinalgTilingLoopType::Loops:
824004a3d4fSNicolas Vasilache     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
825c694588fSMaheshRavishankar   case LinalgTilingLoopType::ParallelLoops:
826004a3d4fSNicolas Vasilache     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
827c694588fSMaheshRavishankar   default:;
828c694588fSMaheshRavishankar   }
829489fec27SNicolas Vasilache   return failure();
8300da755dfSAlexander Belyaev }
8310da755dfSAlexander Belyaev 
832004a3d4fSNicolas Vasilache namespace {
833004a3d4fSNicolas Vasilache /// Helper classes for type list expansion.
834004a3d4fSNicolas Vasilache template <typename... OpTypes>
835004a3d4fSNicolas Vasilache class CanonicalizationPatternList;
836004a3d4fSNicolas Vasilache 
837004a3d4fSNicolas Vasilache template <>
838004a3d4fSNicolas Vasilache class CanonicalizationPatternList<> {
839004a3d4fSNicolas Vasilache public:
840dc4e913bSChris Lattner   static void insert(RewritePatternSet &patterns) {}
841004a3d4fSNicolas Vasilache };
842004a3d4fSNicolas Vasilache 
843004a3d4fSNicolas Vasilache template <typename OpTy, typename... OpTypes>
844004a3d4fSNicolas Vasilache class CanonicalizationPatternList<OpTy, OpTypes...> {
845004a3d4fSNicolas Vasilache public:
846dc4e913bSChris Lattner   static void insert(RewritePatternSet &patterns) {
8473a506b31SChris Lattner     OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
8483a506b31SChris Lattner     CanonicalizationPatternList<OpTypes...>::insert(patterns);
849004a3d4fSNicolas Vasilache   }
850004a3d4fSNicolas Vasilache };
851004a3d4fSNicolas Vasilache } // namespace
852004a3d4fSNicolas Vasilache 
853dc4e913bSChris Lattner RewritePatternSet
854004a3d4fSNicolas Vasilache mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
855dc4e913bSChris Lattner   RewritePatternSet patterns(ctx);
8563a506b31SChris Lattner   populateLinalgTilingCanonicalizationPatterns(patterns);
8572c66b6ecSNicolas Vasilache   return patterns;
8582c66b6ecSNicolas Vasilache }
8592c66b6ecSNicolas Vasilache 
8602c66b6ecSNicolas Vasilache void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
861dc4e913bSChris Lattner     RewritePatternSet &patterns) {
8623a506b31SChris Lattner   auto *ctx = patterns.getContext();
8634c48f016SMatthias Springer   affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
8644c48f016SMatthias Springer   affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
8654c48f016SMatthias Springer   affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
8664c48f016SMatthias Springer   affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
867a54f4eaeSMogball   arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
868a3f42594SLei Zhang 
869a3f42594SLei Zhang   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
870a3f42594SLei Zhang   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
871a3f42594SLei Zhang 
872004a3d4fSNicolas Vasilache   scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
873004a3d4fSNicolas Vasilache   scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
874a3f42594SLei Zhang 
875a3f42594SLei Zhang   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
87681ca5aa4SMatthias Springer   tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
877060208b4SMatthias Springer   tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
878a0e02018SMatthias Springer   tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
879fd0c6f53SAlexander Belyaev   tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
88066e27082SRiver Riddle   ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
881a3f42594SLei Zhang 
882004a3d4fSNicolas Vasilache   CanonicalizationPatternList<
883004a3d4fSNicolas Vasilache #define GET_OP_LIST
884004a3d4fSNicolas Vasilache #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
8853a506b31SChris Lattner       >::insert(patterns);
886baecae83SAlexander Belyaev }
887