1b6281940SNicolas Vasilache //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===// 2b6281940SNicolas Vasilache // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b6281940SNicolas Vasilache // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 8b6281940SNicolas Vasilache // 9b6281940SNicolas Vasilache // This file implements the linalg dialect Tiling pass. 10b6281940SNicolas Vasilache // 11b6281940SNicolas Vasilache //===----------------------------------------------------------------------===// 12b6281940SNicolas Vasilache 1367d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h" 141fc096afSMehdi Amini 15e99fae89SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h" 16f7fda6baSThomas Raoux #include "mlir/Dialect/Affine/LoopUtils.h" 17abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/Utils/Utils.h" 183963b4d0SAlex Zinenko #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 1967d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 20b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h" 21307cfdf5SNicolas Vasilache #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h" 238b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 24129d6e55SSean Silva #include "mlir/Dialect/Tensor/IR/Tensor.h" 25f71f9958SDiego Caballero #include "mlir/Dialect/Utils/IndexingUtils.h" 26b6281940SNicolas Vasilache #include "mlir/IR/AffineExpr.h" 27b6281940SNicolas Vasilache #include "mlir/IR/AffineMap.h" 2806ca5c81SNicolas Vasilache #include "mlir/IR/BuiltinOps.h" 2906ca5c81SNicolas Vasilache #include "mlir/IR/ValueRange.h" 30b6281940SNicolas Vasilache #include "mlir/Transforms/FoldUtils.h" 31b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 3206ca5c81SNicolas Vasilache #include "llvm/ADT/STLExtras.h" 33b6281940SNicolas Vasilache #include "llvm/Support/CommandLine.h" 3467d0d7acSMichele Scuttari #include <utility> 3567d0d7acSMichele Scuttari 3667d0d7acSMichele Scuttari namespace mlir { 3767d0d7acSMichele Scuttari #define GEN_PASS_DEF_LINALGTILINGPASS 3867d0d7acSMichele Scuttari #include "mlir/Dialect/Linalg/Passes.h.inc" 3967d0d7acSMichele Scuttari } // namespace mlir 40b6281940SNicolas Vasilache 41b6281940SNicolas Vasilache using namespace mlir; 424c48f016SMatthias Springer using namespace mlir::affine; 43b6281940SNicolas Vasilache using namespace mlir::linalg; 44c25b20c0SAlex Zinenko using namespace mlir::scf; 45b6281940SNicolas Vasilache 46039b969bSMichele Scuttari #define DEBUG_TYPE "linalg-tiling" 47039b969bSMichele Scuttari 48c9620389SAlexander Belyaev std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap> 49c9620389SAlexander Belyaev mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, 50e99fae89SAlex Zinenko ArrayRef<OpFoldResult> allShapeSizes, 51e99fae89SAlex Zinenko ArrayRef<OpFoldResult> allTileSizes) { 52b6281940SNicolas Vasilache assert(allTileSizes.size() == map.getNumResults()); 53a3adcba6SNicolas Vasilache // Apply `map` to get shape sizes in loop order. 
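// As a hypothetical illustration of the pruning performed below: for a matmul
// with loops (M, N, K) and allTileSizes = {4, 0, 8}, the zero entry for N is
// erased, leaving two ranges that step by 4 and 8 over the M and K sizes, and
// loopIndexToRangeIndex = {0 -> 0, 2 -> 1}.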
54e99fae89SAlex Zinenko SmallVector<OpFoldResult> shapeSizes = 55e99fae89SAlex Zinenko makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes); 565262865aSKazu Hirata SmallVector<OpFoldResult> tileSizes(allTileSizes); 57b6281940SNicolas Vasilache 58b6281940SNicolas Vasilache // Traverse the tile sizes, which are in loop order, erase zeros everywhere. 59bae8a7a7SAlexander Belyaev LoopIndexToRangeIndexMap loopIndexToRangeIndex; 60bae8a7a7SAlexander Belyaev for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) { 61cb7bda2aSMatthias Springer if (getConstantIntValue(tileSizes[idx - zerosCount]) == 62cb7bda2aSMatthias Springer static_cast<int64_t>(0)) { 63a3adcba6SNicolas Vasilache shapeSizes.erase(shapeSizes.begin() + idx - zerosCount); 64bae8a7a7SAlexander Belyaev tileSizes.erase(tileSizes.begin() + idx - zerosCount); 65bae8a7a7SAlexander Belyaev ++zerosCount; 66bae8a7a7SAlexander Belyaev continue; 67b6281940SNicolas Vasilache } 68bae8a7a7SAlexander Belyaev loopIndexToRangeIndex[idx] = idx - zerosCount; 69b6281940SNicolas Vasilache } 70b6281940SNicolas Vasilache 71b6281940SNicolas Vasilache // Create a new range with the applied tile sizes. 72e3de249aSNicolas Vasilache SmallVector<Range, 4> res; 73004a3d4fSNicolas Vasilache for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) 7470e99f38SAlex Zinenko res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]}); 75bae8a7a7SAlexander Belyaev return std::make_tuple(res, loopIndexToRangeIndex); 76b6281940SNicolas Vasilache } 77b6281940SNicolas Vasilache 78c9620389SAlexander Belyaev void mlir::linalg::transformIndexOps( 79c9620389SAlexander Belyaev RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs, 808ea5d190STobias Gysi const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { 8190b7817eSTobias Gysi SmallVector<Value> allIvs(op.getNumLoops(), nullptr); 828c258fdaSJakub Kuderski for (auto en : enumerate(allIvs)) { 8390b7817eSTobias Gysi auto rangeIndex = loopIndexToRangeIndex.find(en.index()); 848ea5d190STobias Gysi if (rangeIndex == loopIndexToRangeIndex.end()) 858ea5d190STobias Gysi continue; 8690b7817eSTobias Gysi en.value() = ivs[rangeIndex->second]; 878ea5d190STobias Gysi } 88e99fae89SAlex Zinenko offsetIndices(b, op, getAsOpFoldResult(allIvs)); 898ea5d190STobias Gysi } 908ea5d190STobias Gysi 913963b4d0SAlex Zinenko /// Asserts that the given index-typed value is strictly positive. If the value 923963b4d0SAlex Zinenko /// is an attribute, asserts at compile time, otherwise emits an assertion 933963b4d0SAlex Zinenko /// checked at runtime. 
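/// For example (hypothetical inputs): a constant attribute 0 trips the
/// compile-time assert, while a dynamic index value lowers to an
/// `arith.cmpi sgt` against zero feeding a `cf.assert`.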
943963b4d0SAlex Zinenko static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
953963b4d0SAlex Zinenko OpFoldResult value) {
9668f58812STres Popp if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
975550c821STres Popp assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
983963b4d0SAlex Zinenko "expected strictly positive tile size and divisor");
993963b4d0SAlex Zinenko return;
1003963b4d0SAlex Zinenko }
1013963b4d0SAlex Zinenko 
1023963b4d0SAlex Zinenko Value zero = b.create<arith::ConstantIndexOp>(0);
1033963b4d0SAlex Zinenko Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
104*4f279a57SKazu Hirata cast<Value>(value), zero);
1053963b4d0SAlex Zinenko b.create<cf::AssertOp>(
1063963b4d0SAlex Zinenko condition,
1073963b4d0SAlex Zinenko b.getStringAttr("expected strictly positive tile size and divisor"));
1083963b4d0SAlex Zinenko }
1093963b4d0SAlex Zinenko 
110a9efcbf4Smuneebkhan85 FailureOr<StaticContinuousTileSizeSpecification>
111a9efcbf4Smuneebkhan85 mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
112a9efcbf4Smuneebkhan85 unsigned dimension,
113a9efcbf4Smuneebkhan85 unsigned targetSize) {
114a9efcbf4Smuneebkhan85 
115a9efcbf4Smuneebkhan85 assert(!op.hasDynamicShape() &&
116a9efcbf4Smuneebkhan85 "cannot compute static multi-tile sizes for an op with dynamic shape");
117a9efcbf4Smuneebkhan85 assert(targetSize > 0 && "target size must be positive");
118a9efcbf4Smuneebkhan85 assert(dimension < op.getNumLoops() && "dimension overflow");
119a9efcbf4Smuneebkhan85 
120a9efcbf4Smuneebkhan85 StaticContinuousTileSizeSpecification spec;
121a9efcbf4Smuneebkhan85 int64_t loopRange = op.getStaticLoopRanges()[dimension];
122a9efcbf4Smuneebkhan85 int64_t tripCount = loopRange / targetSize;
123a9efcbf4Smuneebkhan85 
124a9efcbf4Smuneebkhan85 unsigned tileSize = targetSize;
125a9efcbf4Smuneebkhan85 
126a9efcbf4Smuneebkhan85 spec.tileSizes.push_back(tileSize);
127a9efcbf4Smuneebkhan85 spec.tripCounts.push_back(tripCount);
128a9efcbf4Smuneebkhan85 
129a9efcbf4Smuneebkhan85 int64_t remainderChunk = loopRange % targetSize;
130a9efcbf4Smuneebkhan85 
131a9efcbf4Smuneebkhan85 while (tileSize > 1 && remainderChunk != 0) {
132a9efcbf4Smuneebkhan85 
133a9efcbf4Smuneebkhan85 uint64_t maxPower = llvm::bit_floor(tileSize);
134a9efcbf4Smuneebkhan85 tileSize = maxPower == tileSize ? 
maxPower >> 1 : maxPower; 135a9efcbf4Smuneebkhan85 136a9efcbf4Smuneebkhan85 tripCount = remainderChunk / tileSize; 137a9efcbf4Smuneebkhan85 138a9efcbf4Smuneebkhan85 if (tripCount > 0) { 139a9efcbf4Smuneebkhan85 spec.tileSizes.push_back(tileSize); 140a9efcbf4Smuneebkhan85 spec.tripCounts.push_back(tripCount); 141a9efcbf4Smuneebkhan85 } 142a9efcbf4Smuneebkhan85 143a9efcbf4Smuneebkhan85 remainderChunk = remainderChunk % tileSize; 144a9efcbf4Smuneebkhan85 } 145a9efcbf4Smuneebkhan85 146a9efcbf4Smuneebkhan85 auto tripCountCheck = [&](SmallVector<int64_t> tileSizes, 147a9efcbf4Smuneebkhan85 SmallVector<int64_t> tripCounts, 148a9efcbf4Smuneebkhan85 int64_t range) -> bool { 149a9efcbf4Smuneebkhan85 int64_t computedRange = 0; 150a9efcbf4Smuneebkhan85 for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts)) 151a9efcbf4Smuneebkhan85 computedRange += tileSize * tripCount; 152a9efcbf4Smuneebkhan85 return range == computedRange; 153a9efcbf4Smuneebkhan85 }; 154a9efcbf4Smuneebkhan85 155a9efcbf4Smuneebkhan85 if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange)) 156a9efcbf4Smuneebkhan85 return failure(); 157a9efcbf4Smuneebkhan85 158a9efcbf4Smuneebkhan85 return spec; 159a9efcbf4Smuneebkhan85 } 160a9efcbf4Smuneebkhan85 161a9efcbf4Smuneebkhan85 FailureOr<ContinuousTileSizeSpecification> 162a9efcbf4Smuneebkhan85 mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, 163a9efcbf4Smuneebkhan85 unsigned dimension, 164a9efcbf4Smuneebkhan85 OpFoldResult targetSize, 165a9efcbf4Smuneebkhan85 bool emitAssertions) { 166a9efcbf4Smuneebkhan85 167a9efcbf4Smuneebkhan85 SmallVector<Range> loopRanges = op.getIterationDomain(builder); 168a9efcbf4Smuneebkhan85 unsigned numLoops = loopRanges.size(); 169a9efcbf4Smuneebkhan85 170a9efcbf4Smuneebkhan85 // Bail out on dimension overflow. 171a9efcbf4Smuneebkhan85 if (dimension >= numLoops) 172a9efcbf4Smuneebkhan85 return failure(); 173a9efcbf4Smuneebkhan85 174a9efcbf4Smuneebkhan85 // The code below works only on values. 175a9efcbf4Smuneebkhan85 Location loc = op->getLoc(); 176a9efcbf4Smuneebkhan85 ImplicitLocOpBuilder b(loc, builder); 177a9efcbf4Smuneebkhan85 if (emitAssertions) { 178a9efcbf4Smuneebkhan85 emitIsPositiveIndexAssertion(b, targetSize); 179a9efcbf4Smuneebkhan85 } 180a9efcbf4Smuneebkhan85 Value targetSizeValue = 181a9efcbf4Smuneebkhan85 getValueOrCreateConstantIndexOp(builder, loc, targetSize); 182a9efcbf4Smuneebkhan85 183a9efcbf4Smuneebkhan85 // Find the trip count of the iteration space dimension for which the tile 184a9efcbf4Smuneebkhan85 // sizes are computed. 185a9efcbf4Smuneebkhan85 Value loopRange = getValueOrCreateConstantIndexOp(b, loc, 186a9efcbf4Smuneebkhan85 loopRanges[dimension].size); 187a9efcbf4Smuneebkhan85 ContinuousTileSizeSpecification spec; 188a9efcbf4Smuneebkhan85 189a9efcbf4Smuneebkhan85 // Compute the tile sizes and the respective numbers of tiles. 
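// The decomposition mirrors computeStaticContinuousTileSizes above. As a
// hypothetical example, a loop range of 57 with target size 16 yields tile
// sizes {16, 8, 1} with trip counts {3, 1, 1}, since 3*16 + 1*8 + 1*1 == 57.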
190a9efcbf4Smuneebkhan85 AffineExpr s0 = b.getAffineSymbolExpr(0);
191a9efcbf4Smuneebkhan85 AffineExpr s1 = b.getAffineSymbolExpr(1);
192a9efcbf4Smuneebkhan85 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
193a9efcbf4Smuneebkhan85 return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
194a9efcbf4Smuneebkhan85 };
195a9efcbf4Smuneebkhan85 
196a9efcbf4Smuneebkhan85 Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
197a9efcbf4Smuneebkhan85 Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
198a9efcbf4Smuneebkhan85 
199a9efcbf4Smuneebkhan85 OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
200a9efcbf4Smuneebkhan85 b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
201a9efcbf4Smuneebkhan85 
202a9efcbf4Smuneebkhan85 // emitAssertions above already asserts that targetSize is
203a9efcbf4Smuneebkhan85 // a positive integer.
204a9efcbf4Smuneebkhan85 uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
205a9efcbf4Smuneebkhan85 
206a9efcbf4Smuneebkhan85 assert(tileSizeInt > 0 && "target size must be positive");
207a9efcbf4Smuneebkhan85 
208a9efcbf4Smuneebkhan85 spec.tileSizes.push_back(targetSizeValue);
209a9efcbf4Smuneebkhan85 spec.tripCounts.push_back(tripCountValue);
210a9efcbf4Smuneebkhan85 
211a9efcbf4Smuneebkhan85 while (tileSizeInt > 1) {
212a9efcbf4Smuneebkhan85 uint64_t maxPower = llvm::bit_floor(tileSizeInt);
213a9efcbf4Smuneebkhan85 tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
214a9efcbf4Smuneebkhan85 auto constStepOp =
215a9efcbf4Smuneebkhan85 builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
216a9efcbf4Smuneebkhan85 tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
217a9efcbf4Smuneebkhan85 
218a9efcbf4Smuneebkhan85 tripCountSize = affine::makeComposedFoldedAffineApply(
219a9efcbf4Smuneebkhan85 b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
220a9efcbf4Smuneebkhan85 
221a9efcbf4Smuneebkhan85 // Optimization if tripCount can be determined to be zero. 
222a9efcbf4Smuneebkhan85 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
223a9efcbf4Smuneebkhan85 auto intAttr = cast<IntegerAttr>(attr);
224a9efcbf4Smuneebkhan85 bool isTripCountZero = intAttr.getValue().isZero();
225a9efcbf4Smuneebkhan85 
226a9efcbf4Smuneebkhan85 if (!isTripCountZero) {
227a9efcbf4Smuneebkhan85 spec.tileSizes.push_back(constStepOp);
228a9efcbf4Smuneebkhan85 spec.tripCounts.push_back(tripCountValue);
229a9efcbf4Smuneebkhan85 }
230a9efcbf4Smuneebkhan85 } else {
231a9efcbf4Smuneebkhan85 spec.tileSizes.push_back(constStepOp);
232a9efcbf4Smuneebkhan85 spec.tripCounts.push_back(tripCountValue);
233a9efcbf4Smuneebkhan85 }
234a9efcbf4Smuneebkhan85 
235a9efcbf4Smuneebkhan85 remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
236a9efcbf4Smuneebkhan85 }
237a9efcbf4Smuneebkhan85 
238a9efcbf4Smuneebkhan85 return spec;
239a9efcbf4Smuneebkhan85 }
240a9efcbf4Smuneebkhan85 
24188c5027bSAlex Zinenko FailureOr<StaticMultiSizeSpecification>
24288c5027bSAlex Zinenko mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
24388c5027bSAlex Zinenko int64_t targetSize, int64_t divisor) {
24488c5027bSAlex Zinenko assert(!op.hasDynamicShape() &&
24588c5027bSAlex Zinenko "cannot compute static multi-tile sizes for an op with dynamic shape");
24688c5027bSAlex Zinenko assert(targetSize > 0 && "target size must be positive");
24788c5027bSAlex Zinenko assert(divisor > 0 && "divisor must be positive");
24888c5027bSAlex Zinenko assert(dimension < op.getNumLoops() && "dimension overflow");
24988c5027bSAlex Zinenko 
25088c5027bSAlex Zinenko StaticMultiSizeSpecification spec;
25188c5027bSAlex Zinenko int64_t tripCount = op.getStaticLoopRanges()[dimension];
25288c5027bSAlex Zinenko int64_t a = tripCount / divisor;
25388c5027bSAlex Zinenko int64_t t = (targetSize + divisor - 1) / divisor;
25488c5027bSAlex Zinenko int64_t totalTripCount = (a + t - 1) / t;
25588c5027bSAlex Zinenko spec.lowTileSize = (a / totalTripCount) * divisor;
25688c5027bSAlex Zinenko spec.highTileSize = spec.lowTileSize + divisor;
25788c5027bSAlex Zinenko spec.highTripCount = a % totalTripCount;
25888c5027bSAlex Zinenko spec.lowTripCount = totalTripCount - spec.highTripCount;
25988c5027bSAlex Zinenko if (spec.lowTileSize * spec.lowTripCount +
26088c5027bSAlex Zinenko spec.highTileSize * spec.highTripCount !=
26188c5027bSAlex Zinenko tripCount) {
26288c5027bSAlex Zinenko return failure();
26388c5027bSAlex Zinenko }
26488c5027bSAlex Zinenko return spec;
26588c5027bSAlex Zinenko }
26688c5027bSAlex Zinenko 
2673963b4d0SAlex Zinenko FailureOr<MultiSizeSpecification>
2683963b4d0SAlex Zinenko mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
2693963b4d0SAlex Zinenko unsigned dimension, OpFoldResult targetSize,
2703963b4d0SAlex Zinenko OpFoldResult divisor, bool emitAssertions) {
2713963b4d0SAlex Zinenko // Bail out on dimension overflow.
2723963b4d0SAlex Zinenko if (dimension >= op.getNumLoops())
2733963b4d0SAlex Zinenko return failure();
2743963b4d0SAlex Zinenko 
2753963b4d0SAlex Zinenko // The code below works only on values. 
2764bf84e43SAlexander Belyaev Location loc = op.getLoc(); 2774bf84e43SAlexander Belyaev ImplicitLocOpBuilder b(loc, builder); 2783963b4d0SAlex Zinenko if (emitAssertions) { 2793963b4d0SAlex Zinenko emitIsPositiveIndexAssertion(b, targetSize); 2803963b4d0SAlex Zinenko emitIsPositiveIndexAssertion(b, divisor); 2813963b4d0SAlex Zinenko } 2824bf84e43SAlexander Belyaev Value targetSizeValue = 2834bf84e43SAlexander Belyaev getValueOrCreateConstantIndexOp(builder, loc, targetSize); 2844bf84e43SAlexander Belyaev Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor); 2853963b4d0SAlex Zinenko 2863963b4d0SAlex Zinenko // Find the trip count of the iteration space dimension for which the tile 2873963b4d0SAlex Zinenko // sizes are computed. 288e99fae89SAlex Zinenko SmallVector<OpFoldResult> allShapes = 2893963b4d0SAlex Zinenko op.createFlatListOfOperandDims(b, b.getLoc()); 2903963b4d0SAlex Zinenko AffineMap shapesToLoops = op.getShapesToLoopsMap(); 291e99fae89SAlex Zinenko SmallVector<OpFoldResult> loopRanges = 29226821f75SAlex Zinenko makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops, 29326821f75SAlex Zinenko allShapes); 294e99fae89SAlex Zinenko Value tripCount = 2954bf84e43SAlexander Belyaev getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]); 2963963b4d0SAlex Zinenko 2973963b4d0SAlex Zinenko // Compute the tile sizes and the respective numbers of tiles. 2983963b4d0SAlex Zinenko AffineExpr s0 = b.getAffineSymbolExpr(0); 2993963b4d0SAlex Zinenko AffineExpr s1 = b.getAffineSymbolExpr(1); 3003963b4d0SAlex Zinenko AffineExpr s2 = b.getAffineSymbolExpr(2); 301efc290ceSMatthias Springer auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value { 302efc290ceSMatthias Springer return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs); 3033963b4d0SAlex Zinenko }; 3043963b4d0SAlex Zinenko Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue}); 3053963b4d0SAlex Zinenko Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue}); 3063963b4d0SAlex Zinenko Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t}); 3073963b4d0SAlex Zinenko Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue}); 3083963b4d0SAlex Zinenko Value v = apply(s0 % s1, {a, d}); 3093963b4d0SAlex Zinenko Value u = apply(s0 - s1, {d, v}); 3103963b4d0SAlex Zinenko 3113963b4d0SAlex Zinenko MultiSizeSpecification spec; 3123963b4d0SAlex Zinenko spec.lowTileSize = s; 3133963b4d0SAlex Zinenko spec.highTileSize = apply(s0 + s1, {s, divisorValue}); 3143963b4d0SAlex Zinenko spec.lowTripCount = u; 3153963b4d0SAlex Zinenko spec.highTripCount = v; 3163963b4d0SAlex Zinenko 3173963b4d0SAlex Zinenko // If requested, emit the check that the tile sizes are computed correctly. 3183963b4d0SAlex Zinenko // For example, for iteration dimension size of 15 and the target size 8 it is 3193963b4d0SAlex Zinenko // impossible to find two tile sizes both divisible by 8 that fully cover the 3203963b4d0SAlex Zinenko // original space dimension. 
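// As a hypothetical instance of that failure: with size 15, target size 8 and
// divisor 8, the formulas above produce lowTileSize = 8, highTileSize = 16,
// lowTripCount = 1 and highTripCount = 0, so the covered size is 8 != 15 and
// the emitted cf.assert fires at runtime.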
3213963b4d0SAlex Zinenko if (emitAssertions) {
3223963b4d0SAlex Zinenko AffineExpr s3 = builder.getAffineSymbolExpr(3);
3233963b4d0SAlex Zinenko Value coveredSize =
3243963b4d0SAlex Zinenko apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
3253963b4d0SAlex Zinenko spec.highTileSize, spec.highTripCount});
3263963b4d0SAlex Zinenko Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
3273963b4d0SAlex Zinenko coveredSize, tripCount);
3283963b4d0SAlex Zinenko b.create<cf::AssertOp>(
3293963b4d0SAlex Zinenko equals, builder.getStringAttr(
3303963b4d0SAlex Zinenko "could not compute dynamic multi-size tile shapes"));
3313963b4d0SAlex Zinenko }
3323963b4d0SAlex Zinenko 
3333963b4d0SAlex Zinenko return spec;
3343963b4d0SAlex Zinenko }
3353963b4d0SAlex Zinenko 
336297ba167SChristopher Bate /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is less
337297ba167SChristopher Bate /// than `iterationSize`.
338297ba167SChristopher Bate static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
339297ba167SChristopher Bate OpFoldResult numThreads,
340297ba167SChristopher Bate OpFoldResult iterationSize) {
34122426110SRamkumar Ramachandra std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
34222426110SRamkumar Ramachandra std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
34322426110SRamkumar Ramachandra std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
344297ba167SChristopher Bate if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
345297ba167SChristopher Bate return false;
346297ba167SChristopher Bate return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
34718b92c66SNicolas Vasilache }
34818b92c66SNicolas Vasilache 
349e99fae89SAlex Zinenko /// Build an `affine_max` of all the `vals`.
350e99fae89SAlex Zinenko static OpFoldResult buildMax(OpBuilder &b, Location loc,
351e99fae89SAlex Zinenko ArrayRef<OpFoldResult> vals) {
3524c48f016SMatthias Springer return affine::makeComposedFoldedAffineMax(
35326821f75SAlex Zinenko b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
35426821f75SAlex Zinenko vals);
355e99fae89SAlex Zinenko }
356e99fae89SAlex Zinenko 
357e99fae89SAlex Zinenko /// Build an `affine_min` of all the `vals`.
358e99fae89SAlex Zinenko static OpFoldResult buildMin(OpBuilder &b, Location loc,
359e99fae89SAlex Zinenko ArrayRef<OpFoldResult> vals) {
3604c48f016SMatthias Springer return affine::makeComposedFoldedAffineMin(
36126821f75SAlex Zinenko b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
36226821f75SAlex Zinenko vals);
363e99fae89SAlex Zinenko }
364e99fae89SAlex Zinenko 
36599833cd8SThomas Raoux /// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
36699833cd8SThomas Raoux /// number of threads. 
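/// For instance (hypothetical values): with a single loop range [0, 10) and
/// numThreads = {4}, each thread gets tileSizePerThread = ceil(10/4) = 3 and
/// offsets 0, 3, 6 and 9; the last thread's tile is clamped to
/// min(10 - 9, 3) = 1 by the affine_min built below.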
36799833cd8SThomas Raoux static void calculateTileOffsetsAndSizes( 368eb2f946eSAlexander Belyaev RewriterBase &b, Location loc, scf::ForallOp forallOp, 36999833cd8SThomas Raoux ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges, 37099833cd8SThomas Raoux bool omitTileOffsetBoundsCheck, 37122426110SRamkumar Ramachandra std::optional<ArrayRef<OpFoldResult>> nominalTileSizes, 37299833cd8SThomas Raoux SmallVector<OpFoldResult> &tiledOffsets, 37399833cd8SThomas Raoux SmallVector<OpFoldResult> &tiledSizes) { 37406ca5c81SNicolas Vasilache OpBuilder::InsertionGuard g(b); 375eb2f946eSAlexander Belyaev b.setInsertionPointToStart(forallOp.getBody(0)); 37606ca5c81SNicolas Vasilache 3776b4c1228Ssrcarroll SmallVector<Value> threadIds = forallOp.getInductionVars(); 378f4d75863SJakub Kuderski SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector( 379f4d75863SJakub Kuderski numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); }); 38018b92c66SNicolas Vasilache int64_t nLoops = loopRanges.size(); 38118b92c66SNicolas Vasilache tiledOffsets.reserve(nLoops); 38218b92c66SNicolas Vasilache tiledSizes.reserve(nLoops); 383297ba167SChristopher Bate for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { 38418b92c66SNicolas Vasilache bool overflow = loopIdx >= numThreads.size(); 38518b92c66SNicolas Vasilache bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); 38618b92c66SNicolas Vasilache // Degenerate case: take the whole domain. 38718b92c66SNicolas Vasilache if (overflow || isZero) { 38818b92c66SNicolas Vasilache tiledOffsets.push_back(loopRanges[loopIdx].offset); 38918b92c66SNicolas Vasilache tiledSizes.push_back(loopRanges[loopIdx].size); 39018b92c66SNicolas Vasilache continue; 39118b92c66SNicolas Vasilache } 39218b92c66SNicolas Vasilache 39318b92c66SNicolas Vasilache // Tiled case: compute the offset and size. 394a7a892aeSMehdi Amini AffineExpr i, j, m, n, o; 39518b92c66SNicolas Vasilache bindDims(b.getContext(), i, j); 396a7a892aeSMehdi Amini bindSymbols(b.getContext(), m, n, o); 397e99fae89SAlex Zinenko OpFoldResult size = loopRanges[loopIdx].size; 398e99fae89SAlex Zinenko OpFoldResult offset = loopRanges[loopIdx].offset; 399e99fae89SAlex Zinenko OpFoldResult threadId = threadIds[threadIdIdx]; 40018b92c66SNicolas Vasilache // Symbolic fixed max size per thread. 40118b92c66SNicolas Vasilache // TODO: floor + 0/1 depending on case for better load-balancing. 402297ba167SChristopher Bate OpFoldResult tileSizePerThread = 4036fa6901bSKazu Hirata nominalTileSizes.has_value() 404297ba167SChristopher Bate ? (*nominalTileSizes)[loopIdx] 405297ba167SChristopher Bate : makeComposedFoldedAffineApply( 406a7a892aeSMehdi Amini b, loc, m.ceilDiv(n), 407297ba167SChristopher Bate ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]}); 408297ba167SChristopher Bate 40918b92c66SNicolas Vasilache // Dynamic offset shifted by threadId * maxSizePerThread. 410297ba167SChristopher Bate OpFoldResult offsetPerThread = makeComposedFoldedAffineApply( 411a7a892aeSMehdi Amini b, loc, i + j * m, {offset, threadId, tileSizePerThread}); 41218b92c66SNicolas Vasilache // Dynamic upper-bound depending on the threadId. 
413297ba167SChristopher Bate OpFoldResult residualTileSize = makeComposedFoldedAffineApply( 414a7a892aeSMehdi Amini b, loc, i + j * m - n, 415297ba167SChristopher Bate {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); 416297ba167SChristopher Bate if (!isConstantIntValue(residualTileSize, 0)) { 417297ba167SChristopher Bate OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( 418a7a892aeSMehdi Amini b, loc, -i + m, {offsetPerThread, size}); 419e99fae89SAlex Zinenko tileSizePerThread = 420e99fae89SAlex Zinenko buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread}); 421297ba167SChristopher Bate } 422297ba167SChristopher Bate 42318b92c66SNicolas Vasilache tiledOffsets.push_back(offsetPerThread); 42418b92c66SNicolas Vasilache // TODO: if tileSizePerThread <= 0 early exit. 425297ba167SChristopher Bate if (!omitTileOffsetBoundsCheck && 426297ba167SChristopher Bate !canOmitTileOffsetInBoundsCheck(tileSizePerThread, 427297ba167SChristopher Bate nonZeroNumThreads[threadIdIdx], size)) 428e99fae89SAlex Zinenko tileSizePerThread = 429e99fae89SAlex Zinenko buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread}); 430297ba167SChristopher Bate 431297ba167SChristopher Bate tiledSizes.push_back(tileSizePerThread); 43218b92c66SNicolas Vasilache ++threadIdIdx; 43318b92c66SNicolas Vasilache } 43499833cd8SThomas Raoux } 43599833cd8SThomas Raoux 4360da755dfSAlexander Belyaev template <typename LoopTy> 437489fec27SNicolas Vasilache static FailureOr<TiledLinalgOp> 438e99fae89SAlex Zinenko tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes, 439c694588fSMaheshRavishankar const LinalgTilingOptions &options) { 44006ca5c81SNicolas Vasilache OpBuilder::InsertionGuard g(b); 44106ca5c81SNicolas Vasilache 442004a3d4fSNicolas Vasilache auto nLoops = op.getNumLoops(); 443004a3d4fSNicolas Vasilache // Initial tile sizes may be too big, only take the first nLoops. 444004a3d4fSNicolas Vasilache tileSizes = tileSizes.take_front(nLoops); 445004a3d4fSNicolas Vasilache 446cb7bda2aSMatthias Springer if (llvm::all_of(tileSizes, [](OpFoldResult ofr) { 447cb7bda2aSMatthias Springer return getConstantIntValue(ofr) == static_cast<int64_t>(0); 448cb7bda2aSMatthias Springer })) { 449526dfe3fSMaheshRavishankar TiledLinalgOp tiledOp; 450526dfe3fSMaheshRavishankar tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation())); 451526dfe3fSMaheshRavishankar tiledOp.tensorResults.assign(tiledOp.op->result_begin(), 452526dfe3fSMaheshRavishankar tiledOp.op->result_end()); 453526dfe3fSMaheshRavishankar return tiledOp; 454526dfe3fSMaheshRavishankar } 455f60bbb6cSJose Ignacio Gomez 456c694588fSMaheshRavishankar // 1. Build the tiled loop ranges. 
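// As a hypothetical illustration of the pruning and interchange handling
// below: for a matmul with iterators {parallel, parallel, reduction}, tile
// sizes {4, 0, 8} and options.interchangeVector = {2, 0, 1}, the untiled N
// loop is dropped, leaving iteratorTypes = {parallel, reduction} and a
// recomputed interchange vector of {1, 0}.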
457e99fae89SAlex Zinenko SmallVector<OpFoldResult> allShapeSizes = 458e99fae89SAlex Zinenko op.createFlatListOfOperandDims(b, op.getLoc()); 45901c44185SNicolas Vasilache AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); 460a3adcba6SNicolas Vasilache if (!shapeSizesToLoopsMap) 461489fec27SNicolas Vasilache return failure(); 462bae8a7a7SAlexander Belyaev 4639fa59e76SBenjamin Kramer auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( 464a3adcba6SNicolas Vasilache b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); 46542444d0cSMaheshRavishankar 466e6598b05SOleg Shyshkov SmallVector<utils::IteratorType, 4> iteratorTypes; 467c54bc8bdSOleg Shyshkov for (const auto &attr : enumerate(op.getIteratorTypesArray())) { 468c694588fSMaheshRavishankar if (loopIndexToRangeIndex.count(attr.index())) 469c694588fSMaheshRavishankar iteratorTypes.push_back(attr.value()); 470c694588fSMaheshRavishankar } 471c694588fSMaheshRavishankar // If interchangeVector is empty, use the identity. Build the permutation map 472c694588fSMaheshRavishankar // otherwise. 473c694588fSMaheshRavishankar auto invPermutationMap = 474c694588fSMaheshRavishankar AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); 475c694588fSMaheshRavishankar if (!options.interchangeVector.empty()) { 476c694588fSMaheshRavishankar // Based on the pruned iterations (due to zero tile size), recompute the 477c694588fSMaheshRavishankar // interchange vector. 478c694588fSMaheshRavishankar SmallVector<unsigned, 4> interchangeVector; 479c694588fSMaheshRavishankar interchangeVector.reserve(options.interchangeVector.size()); 480c694588fSMaheshRavishankar for (auto pos : options.interchangeVector) { 481c694588fSMaheshRavishankar auto it = loopIndexToRangeIndex.find(pos); 482c694588fSMaheshRavishankar if (it == loopIndexToRangeIndex.end()) 483c694588fSMaheshRavishankar continue; 484c694588fSMaheshRavishankar interchangeVector.push_back(it->second); 485c694588fSMaheshRavishankar } 48601c44185SNicolas Vasilache // Interchange vector is guaranteed to be a permutation, 48701c44185SNicolas Vasilache // `inversePermutation` must succeed. 488c694588fSMaheshRavishankar invPermutationMap = inversePermutation( 489c694588fSMaheshRavishankar AffineMap::getPermutationMap(interchangeVector, b.getContext())); 49001c44185SNicolas Vasilache assert(invPermutationMap); 4919072f1b5STobias Gysi SmallVector<int64_t> permutation(interchangeVector.begin(), 4929072f1b5STobias Gysi interchangeVector.end()); 4939072f1b5STobias Gysi applyPermutationToVector(loopRanges, permutation); 4949072f1b5STobias Gysi applyPermutationToVector(iteratorTypes, permutation); 495c694588fSMaheshRavishankar } 496b6281940SNicolas Vasilache 497f365e85cSMahesh Ravishankar // Handle distribution. Create a vector of the same size of loops that are to 498f365e85cSMahesh Ravishankar // be tiled. 499f365e85cSMahesh Ravishankar SmallVector<linalg::ProcInfo> procInfo; 500f365e85cSMahesh Ravishankar if (options.distribution) { 501f365e85cSMahesh Ravishankar procInfo.resize( 502f365e85cSMahesh Ravishankar iteratorTypes.size(), 503f365e85cSMahesh Ravishankar linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None}); 50423bd2e96SMatthias Springer // Collect loop ranges of tiled loops, loops that are parallel. 
505f365e85cSMahesh Ravishankar SmallVector<Range> parallelLoopRanges; 506e0568fa7SAdrian Kuegel for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { 507f365e85cSMahesh Ravishankar if (!isParallelIterator(iteratorType.value())) 508f365e85cSMahesh Ravishankar break; 509f365e85cSMahesh Ravishankar parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); 510f365e85cSMahesh Ravishankar } 511f365e85cSMahesh Ravishankar auto returnedProcInfo = 512f365e85cSMahesh Ravishankar options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges); 513f365e85cSMahesh Ravishankar unsigned procIdIdx = 0; 514f365e85cSMahesh Ravishankar // Update the distribution information for the loops. 515e0568fa7SAdrian Kuegel for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) { 516f365e85cSMahesh Ravishankar if (!isParallelIterator(iteratorType.value())) 517f365e85cSMahesh Ravishankar break; 518f365e85cSMahesh Ravishankar procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++]; 519f365e85cSMahesh Ravishankar } 520f365e85cSMahesh Ravishankar } 521f365e85cSMahesh Ravishankar 522c694588fSMaheshRavishankar // 2. Create the tiled loops. 52358ddeba3SAlexander Belyaev LinalgOp res = op; 524a3adcba6SNicolas Vasilache SmallVector<Value, 4> ivs, tensorResults; 52516488dc3STobias Gysi auto tiledLoopBodyBuilder = 5264a661602SNicolas Vasilache [&](OpBuilder &builder, Location loc, ValueRange localIvs, 52716488dc3STobias Gysi ValueRange operandValuesToUse) -> scf::ValueVector { 528b4bc72afSAlex Zinenko ivs.assign(localIvs.begin(), localIvs.end()); 529f60bbb6cSJose Ignacio Gomez 530a3adcba6SNicolas Vasilache // When an `interchangeVector` is present, it has been applied to the 531a3adcba6SNicolas Vasilache // loop ranges and the iterator types. Apply its inverse to the 532a3adcba6SNicolas Vasilache // resulting loop `ivs` to match the op definition. 533a3adcba6SNicolas Vasilache SmallVector<Value, 4> interchangedIvs; 53423bd2e96SMatthias Springer if (!options.interchangeVector.empty()) { 53523bd2e96SMatthias Springer for (AffineExpr result : invPermutationMap.getResults()) 53623bd2e96SMatthias Springer interchangedIvs.push_back( 5371609f1c2Slong.chen ivs[cast<AffineDimExpr>(result).getPosition()]); 53823bd2e96SMatthias Springer } else { 539a3adcba6SNicolas Vasilache interchangedIvs.assign(ivs.begin(), ivs.end()); 54023bd2e96SMatthias Springer } 541f60bbb6cSJose Ignacio Gomez 54216488dc3STobias Gysi // Tile the `operandValuesToUse` that either match the `op` operands 54316488dc3STobias Gysi // themselves or the tile loop arguments forwarding them. 
54416488dc3STobias Gysi assert(operandValuesToUse.size() == 545a7cccb9cSAlexander Belyaev static_cast<size_t>(op->getNumOperands()) && 54616488dc3STobias Gysi "expect the number of operands and inputs and outputs to match"); 54716488dc3STobias Gysi SmallVector<Value> valuesToTile = operandValuesToUse; 548e99fae89SAlex Zinenko SmallVector<OpFoldResult> sizeBounds = 54926821f75SAlex Zinenko makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap, 55026821f75SAlex Zinenko allShapeSizes); 551e99fae89SAlex Zinenko SmallVector<Value> tiledOperands = makeTiledShapes( 552e99fae89SAlex Zinenko b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes, 553e99fae89SAlex Zinenko sizeBounds, 554e99fae89SAlex Zinenko /*omitPartialTileCheck=*/false); 555a3adcba6SNicolas Vasilache 556ff6e5508SAlex Zinenko SmallVector<Type> resultTensorTypes = 557ff6e5508SAlex Zinenko getTensorOutputTypes(op, tiledOperands); 558f286af29SAlexander Belyaev res = clone(b, op, resultTensorTypes, tiledOperands); 559ff6e5508SAlex Zinenko tensorResults = 560ff6e5508SAlex Zinenko insertSlicesBack(builder, loc, op, tiledOperands, res->getResults()); 56158ddeba3SAlexander Belyaev return scf::ValueVector(tensorResults.begin(), tensorResults.end()); 56284a880e1SNicolas Vasilache }; 56384a880e1SNicolas Vasilache GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes, 564f365e85cSMahesh Ravishankar tiledLoopBodyBuilder, procInfo); 565b6281940SNicolas Vasilache 566d69bccf1STobias Gysi // 3. Transform IndexOp results w.r.t. the tiling. 56758ddeba3SAlexander Belyaev transformIndexOps(b, res, ivs, loopIndexToRangeIndex); 568bae8a7a7SAlexander Belyaev 569c694588fSMaheshRavishankar // 4. Gather the newly created loops and return them with the new op. 5700da755dfSAlexander Belyaev SmallVector<Operation *, 8> loops; 571b6281940SNicolas Vasilache loops.reserve(ivs.size()); 572a5bfd32cSLei Zhang for (auto iv : ivs) { 5735550c821STres Popp if (isa<BlockArgument>(iv)) { 5745550c821STres Popp loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp()); 575a5bfd32cSLei Zhang assert(loops.back() && "no owner found for induction variable!"); 57641d41200SMaheshRavishankar } else { 57741d41200SMaheshRavishankar // TODO: Instead of doing this, try to recover the ops used instead of the 57841d41200SMaheshRavishankar // loop. 57941d41200SMaheshRavishankar loops.push_back(nullptr); 58041d41200SMaheshRavishankar } 581a5bfd32cSLei Zhang } 582a3adcba6SNicolas Vasilache 583a3adcba6SNicolas Vasilache // 5. Get the tensor results from the outermost loop if available. Otherwise 584a3adcba6SNicolas Vasilache // use the previously captured `tensorResults`. 585a3adcba6SNicolas Vasilache Operation *outermostLoop = nullptr; 586a3adcba6SNicolas Vasilache for (Operation *loop : loops) 587a3adcba6SNicolas Vasilache if ((outermostLoop = loop)) 588a3adcba6SNicolas Vasilache break; 589a3adcba6SNicolas Vasilache 59058ddeba3SAlexander Belyaev return TiledLinalgOp{ 59158ddeba3SAlexander Belyaev res, loops, outermostLoop ? 
outermostLoop->getResults() : tensorResults};
592b6281940SNicolas Vasilache }
593b6281940SNicolas Vasilache 
594eb2f946eSAlexander Belyaev FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
595eb2f946eSAlexander Belyaev RewriterBase &b, PartialReductionOpInterface op,
596eb2f946eSAlexander Belyaev ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
59722426110SRamkumar Ramachandra std::optional<ArrayAttr> mapping) {
598f7fda6baSThomas Raoux Location loc = op.getLoc();
599f7fda6baSThomas Raoux OpBuilder::InsertionGuard g(b);
60006ca5c81SNicolas Vasilache 
601f7fda6baSThomas Raoux // Ops implementing PartialReductionOpInterface are expected to implement
602f7fda6baSThomas Raoux // TilingInterface.
60306ca5c81SNicolas Vasilache // TODO: proper core mechanism to tie interfaces together.
604f7fda6baSThomas Raoux auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
60506ca5c81SNicolas Vasilache 
60606ca5c81SNicolas Vasilache // Ops implementing PartialReductionOpInterface are not necessarily expected
60706ca5c81SNicolas Vasilache // to implement TilingInterface. This cast is unsafe atm.
60806ca5c81SNicolas Vasilache // TODO: proper core mechanism to tie interfaces together.
60906ca5c81SNicolas Vasilache // TODO: this function requires a pair of interfaces ..
61006ca5c81SNicolas Vasilache auto destinationStyleOp =
61106ca5c81SNicolas Vasilache dyn_cast<DestinationStyleOpInterface>(op.getOperation());
61206ca5c81SNicolas Vasilache if (!destinationStyleOp)
61306ca5c81SNicolas Vasilache return b.notifyMatchFailure(op, "not a destination style op");
61406ca5c81SNicolas Vasilache 
61506ca5c81SNicolas Vasilache // Actually this only works for Linalg ops atm.
61606ca5c81SNicolas Vasilache auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
61706ca5c81SNicolas Vasilache if (!linalgOp)
61806ca5c81SNicolas Vasilache return b.notifyMatchFailure(op, "not a linalg op");
61906ca5c81SNicolas Vasilache 
620f7fda6baSThomas Raoux SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
621f7fda6baSThomas Raoux if (op->getNumResults() != 1)
622f7fda6baSThomas Raoux return b.notifyMatchFailure(
623f7fda6baSThomas Raoux op, "don't support ops with multiple results for now");
62406ca5c81SNicolas Vasilache 
625f7fda6baSThomas Raoux SmallVector<utils::IteratorType> iterators =
626f7fda6baSThomas Raoux tilingInterfaceOp.getLoopIteratorTypes();
627f7fda6baSThomas Raoux SmallVector<unsigned> redDims;
62806ca5c81SNicolas Vasilache linalgOp.getReductionDims(redDims);
629f7fda6baSThomas Raoux if (redDims.size() != 1)
630f7fda6baSThomas Raoux return b.notifyMatchFailure(
631f7fda6baSThomas Raoux op, "only support ops with one reduction dimension.");
632f7fda6baSThomas Raoux if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
633f7fda6baSThomas Raoux return b.notifyMatchFailure(op, "if tile sizes are present they must have as "
634f7fda6baSThomas Raoux "many elements as number of threads");
635f7fda6baSThomas Raoux int reductionDim = static_cast<int>(redDims.front());
63606ca5c81SNicolas Vasilache 
637faac8989SAlex Zinenko if (redDims.front() >= numThreads.size())
638faac8989SAlex Zinenko return b.notifyMatchFailure(
639faac8989SAlex Zinenko op, "reduction dimension must be mapped to threads");
640faac8989SAlex Zinenko 
64106ca5c81SNicolas Vasilache // 1. Create the initial tensor value. 
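// (Hypothetical example: for an f32 sum-reduction whose reduction dimension is
// split across numThreads = {8}, this materializes a partial-result tensor
// carrying one slot per thread along the tiled reduction dimension, filled
// with the reduction's neutral element; each thread fills its slice in step 6
// and step 7 merges the partials back into the final result.)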
6429329b20dSKunwar Grover FailureOr<SmallVector<Value>> maybeInitTensors =
643f7fda6baSThomas Raoux op.generateInitialTensorForPartialReduction(b, loc, numThreads,
644f7fda6baSThomas Raoux reductionDim);
6459329b20dSKunwar Grover if (failed(maybeInitTensors))
6469329b20dSKunwar Grover return b.notifyMatchFailure(
6479329b20dSKunwar Grover op, "Failed to create initial tensors for partial reduction");
6489329b20dSKunwar Grover SmallVector<Value> &initTensors = maybeInitTensors.value();
649f7fda6baSThomas Raoux 
650f7fda6baSThomas Raoux // Gather destination tensors.
651f7fda6baSThomas Raoux SmallVector<Value> dest;
652f7fda6baSThomas Raoux if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
653f7fda6baSThomas Raoux return b.notifyMatchFailure(op, "failed to get destination tensors");
654f7fda6baSThomas Raoux 
655f7fda6baSThomas Raoux Operation *tiledOp = nullptr;
656f7fda6baSThomas Raoux 
657f4d75863SJakub Kuderski SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658f4d75863SJakub Kuderski numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659f7fda6baSThomas Raoux SmallVector<Value> materializedNonZeroNumThreads =
660c888a0ceSNicolas Vasilache getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661f7fda6baSThomas Raoux 
662eb2f946eSAlexander Belyaev // 2. Create the ForallOp with an empty region.
663eb2f946eSAlexander Belyaev scf::ForallOp forallOp = b.create<scf::ForallOp>(
6649329b20dSKunwar Grover loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
6659329b20dSKunwar Grover mapping);
666f7fda6baSThomas Raoux 
66706ca5c81SNicolas Vasilache // 3. Calculate the tile offsets and sizes for the subsequent loop that will
668eb2f946eSAlexander Belyaev // be nested under `forallOp`.
669f7fda6baSThomas Raoux SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
670eb2f946eSAlexander Belyaev calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
671f7fda6baSThomas Raoux /*omitTileOffsetBoundsCheck =*/false,
672eb2f946eSAlexander Belyaev /*nominalTileSizes=*/std::nullopt, tiledOffsets,
673eb2f946eSAlexander Belyaev tiledSizes);
674f7fda6baSThomas Raoux 
6759329b20dSKunwar Grover // 4. Clone the tileable op and update its destination operands to use the
676eb2f946eSAlexander Belyaev // output bbArgs of the ForallOp.
6771da04b04SMahesh Ravishankar SmallVector<Value> tilingResults;
67876ead96cSMaheshRavishankar ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
67906ca5c81SNicolas Vasilache {
680eb2f946eSAlexander Belyaev // 4.a. RAII guard, inserting within forallOp, before terminator. 
68106ca5c81SNicolas Vasilache OpBuilder::InsertionGuard g(b);
682eb2f946eSAlexander Belyaev b.setInsertionPoint(forallOp.getTerminator());
68306ca5c81SNicolas Vasilache 
68406ca5c81SNicolas Vasilache SmallVector<Value> tiledDpsInitOperands;
6850b2197b0SMatthias Springer for (Value initOperand : destinationStyleOp.getDpsInits()) {
6860b2197b0SMatthias Springer auto *it = llvm::find(dest, initOperand);
687f7fda6baSThomas Raoux assert(it != dest.end() && "dest operand not found in dest");
688f7fda6baSThomas Raoux unsigned destNum = std::distance(dest.begin(), it);
689f7fda6baSThomas Raoux SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
69006ca5c81SNicolas Vasilache SmallVector<OpFoldResult> outOffsets(numThreads.size(),
69106ca5c81SNicolas Vasilache b.getIndexAttr(0));
692f7fda6baSThomas Raoux SmallVector<OpFoldResult> sizes = tiledSizes;
693f7fda6baSThomas Raoux sizes[reductionDim] = b.getIndexAttr(1);
6946b4c1228Ssrcarroll outOffsets[reductionDim] = forallOp.getInductionVars()[0];
695f7fda6baSThomas Raoux // TODO: use SubsetExtractOpInterface once it is available.
69606ca5c81SNicolas Vasilache tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
6970b2197b0SMatthias Springer loc, cast<RankedTensorType>(initOperand.getType()),
69806ca5c81SNicolas Vasilache destBbArgs[destNum], outOffsets, sizes, strides));
699f7fda6baSThomas Raoux }
70006ca5c81SNicolas Vasilache 
70106ca5c81SNicolas Vasilache // 4.b. Clone the op and update init operands.
7024d67b278SJeff Niu // We cannot use an IRMapping here because it can replace
70306ca5c81SNicolas Vasilache // different OpOperands with the same value.
70406ca5c81SNicolas Vasilache Operation *clonedOp = b.clone(*op.getOperation());
7055fcf907bSMatthias Springer b.modifyOpInPlace(clonedOp, [&]() {
70606ca5c81SNicolas Vasilache for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
7070b2197b0SMatthias Springer cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
70806ca5c81SNicolas Vasilache tiledDpsInitOperands)) {
7090b2197b0SMatthias Springer initOperandPtr.set(tiledInitValue);
71006ca5c81SNicolas Vasilache }
71106ca5c81SNicolas Vasilache });
712f7fda6baSThomas Raoux 
713f7fda6baSThomas Raoux // 5. Tile the cloned op and delete the clone. 
714f7fda6baSThomas Raoux if (tileSizes.empty()) { 715809e3d8cSMahesh Ravishankar FailureOr<TilingResult> tilingResult = 71606ca5c81SNicolas Vasilache cast<TilingInterface>(clonedOp).getTiledImplementation( 71706ca5c81SNicolas Vasilache b, tiledOffsets, tiledSizes); 718242b3091SNicolas Vasilache if (failed(tilingResult)) 719242b3091SNicolas Vasilache return clonedOp->emitError("Failed to tile op: "); 720242b3091SNicolas Vasilache if (tilingResult->tiledOps.size() != 1) { 721242b3091SNicolas Vasilache return clonedOp->emitError("expected a single produced tiled op, got ") 722242b3091SNicolas Vasilache << tilingResult->tiledOps.size(); 723242b3091SNicolas Vasilache } 724809e3d8cSMahesh Ravishankar tiledOp = tilingResult->tiledOps.front(); 725809e3d8cSMahesh Ravishankar tilingResults = tilingResult->tiledValues; 726f7fda6baSThomas Raoux } else { 727f7fda6baSThomas Raoux LinalgTilingOptions options; 72806ca5c81SNicolas Vasilache FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>( 72906ca5c81SNicolas Vasilache b, cast<LinalgOp>(clonedOp), tileSizes, options); 73006ca5c81SNicolas Vasilache if (failed(maybeTiled)) 73106ca5c81SNicolas Vasilache return b.notifyMatchFailure(op, "failed tileLinalgOpImpl"); 73206ca5c81SNicolas Vasilache 733eb2f946eSAlexander Belyaev SmallVector<Value> ids = forallOp.getInductionVars(); 73406ca5c81SNicolas Vasilache mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids, 735f7fda6baSThomas Raoux materializedNonZeroNumThreads); 736242b3091SNicolas Vasilache if (maybeTiled->loops.size() != 1) { 737242b3091SNicolas Vasilache return clonedOp->emitError("expected a single produced loop"); 738242b3091SNicolas Vasilache } 73906ca5c81SNicolas Vasilache tiledOp = maybeTiled->op; 74006ca5c81SNicolas Vasilache tilingResults = maybeTiled->loops.front()->getResults(); 741f7fda6baSThomas Raoux } 74206ca5c81SNicolas Vasilache 743f7fda6baSThomas Raoux b.eraseOp(clonedOp); 74406ca5c81SNicolas Vasilache } 745f7fda6baSThomas Raoux 746f7fda6baSThomas Raoux // 6. Insert the partial reductions back into a new tensor. 74706ca5c81SNicolas Vasilache for (auto [index, result, bbArg] : llvm::zip( 74806ca5c81SNicolas Vasilache llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) { 74906ca5c81SNicolas Vasilache // 6.a. Partial subset information is inserted just before the terminator. 
75006ca5c81SNicolas Vasilache OpBuilder::InsertionGuard g(b); 751eb2f946eSAlexander Belyaev b.setInsertionPoint(forallOp.getTerminator()); 75206ca5c81SNicolas Vasilache 753f7fda6baSThomas Raoux SmallVector<OpFoldResult> resultOffsets, resultSizes; 754f7fda6baSThomas Raoux if (failed(tilingInterfaceOp.getResultTilePosition( 755f7fda6baSThomas Raoux b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes))) 756f7fda6baSThomas Raoux return op->emitOpError("output offsets couldn't be calculated"); 757f7fda6baSThomas Raoux SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank; 758f7fda6baSThomas Raoux int64_t offIdx = 0; 759f7fda6baSThomas Raoux int64_t sizeIdx = 0; 760f7fda6baSThomas Raoux for (int64_t i = 0, e = numThreads.size(); i < e; ++i) { 761f7fda6baSThomas Raoux if (i == reductionDim) { 7626b4c1228Ssrcarroll resultOffsetsRank.push_back(forallOp.getInductionVars()[0]); 763f7fda6baSThomas Raoux resultSizesRank.push_back(b.getIndexAttr(1)); 764f7fda6baSThomas Raoux continue; 765f7fda6baSThomas Raoux } 766f7fda6baSThomas Raoux resultOffsetsRank.push_back(resultOffsets[offIdx++]); 767f7fda6baSThomas Raoux resultSizesRank.push_back(resultSizes[sizeIdx++]); 768f7fda6baSThomas Raoux } 769f7fda6baSThomas Raoux SmallVector<OpFoldResult> strides(resultSizesRank.size(), 770f7fda6baSThomas Raoux b.getIndexAttr(1)); 77106ca5c81SNicolas Vasilache 77206ca5c81SNicolas Vasilache // 6.b. Parallel insertions are inserted at the end of the combining 77306ca5c81SNicolas Vasilache // terminator. 774eb2f946eSAlexander Belyaev b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); 775f7fda6baSThomas Raoux b.create<tensor::ParallelInsertSliceOp>( 776f7fda6baSThomas Raoux loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); 777f7fda6baSThomas Raoux } 77806ca5c81SNicolas Vasilache 779f7fda6baSThomas Raoux // 7. Merge the partial reductions. 780eb2f946eSAlexander Belyaev b.setInsertionPointAfter(forallOp); 781b99d0b34SMaheshRavishankar FailureOr<MergeResult> mergeResult = 782eb2f946eSAlexander Belyaev op.mergeReductions(b, loc, forallOp->getResults(), reductionDim); 783b99d0b34SMaheshRavishankar if (failed(mergeResult)) { 784b99d0b34SMaheshRavishankar return failure(); 785b99d0b34SMaheshRavishankar } 786b99d0b34SMaheshRavishankar b.replaceOp(op, mergeResult->replacements); 78706ca5c81SNicolas Vasilache 78806ca5c81SNicolas Vasilache // 8. Return. 789eb2f946eSAlexander Belyaev ForallReductionTilingResult results; 7909329b20dSKunwar Grover results.initialValues = initTensors; 791eb2f946eSAlexander Belyaev results.loops = forallOp; 792b99d0b34SMaheshRavishankar results.parallelTiledOps.push_back(tiledOp); 793b99d0b34SMaheshRavishankar results.mergeOps.append(mergeResult->mergeOps); 794f7fda6baSThomas Raoux return results; 795f7fda6baSThomas Raoux } 796f7fda6baSThomas Raoux 797c694588fSMaheshRavishankar template <typename LoopTy> 798489fec27SNicolas Vasilache FailureOr<TiledLinalgOp> static tileLinalgOpImpl( 7994a661602SNicolas Vasilache RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) { 800c694588fSMaheshRavishankar OpBuilder::InsertionGuard g(b); 801c694588fSMaheshRavishankar b.setInsertionPoint(op); 802c694588fSMaheshRavishankar 8034643fd27SNicolas Vasilache if (!options.tileSizeComputationFunction) 804489fec27SNicolas Vasilache return failure(); 8054643fd27SNicolas Vasilache 806c694588fSMaheshRavishankar // Enforce the convention that "tiling by zero" skips tiling a particular 807c694588fSMaheshRavishankar // dimension. 
This convention is significantly simpler to handle instead of 808c694588fSMaheshRavishankar // adjusting affine maps to account for missing dimensions. 809c694588fSMaheshRavishankar auto nLoops = op.getNumLoops(); 810e99fae89SAlex Zinenko SmallVector<OpFoldResult> tileSizeVector = 811e99fae89SAlex Zinenko getAsOpFoldResult(options.tileSizeComputationFunction(b, op)); 812c694588fSMaheshRavishankar if (tileSizeVector.size() < nLoops) { 813e99fae89SAlex Zinenko tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0)); 814c694588fSMaheshRavishankar } 815c694588fSMaheshRavishankar 816c694588fSMaheshRavishankar return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options); 817c694588fSMaheshRavishankar } 818c694588fSMaheshRavishankar 819489fec27SNicolas Vasilache FailureOr<TiledLinalgOp> 8204a661602SNicolas Vasilache mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op, 821004a3d4fSNicolas Vasilache const LinalgTilingOptions &options) { 822c694588fSMaheshRavishankar switch (options.loopType) { 823c694588fSMaheshRavishankar case LinalgTilingLoopType::Loops: 824004a3d4fSNicolas Vasilache return tileLinalgOpImpl<scf::ForOp>(b, op, options); 825c694588fSMaheshRavishankar case LinalgTilingLoopType::ParallelLoops: 826004a3d4fSNicolas Vasilache return tileLinalgOpImpl<scf::ParallelOp>(b, op, options); 827c694588fSMaheshRavishankar default:; 828c694588fSMaheshRavishankar } 829489fec27SNicolas Vasilache return failure(); 8300da755dfSAlexander Belyaev } 8310da755dfSAlexander Belyaev 832004a3d4fSNicolas Vasilache namespace { 833004a3d4fSNicolas Vasilache /// Helper classes for type list expansion. 834004a3d4fSNicolas Vasilache template <typename... OpTypes> 835004a3d4fSNicolas Vasilache class CanonicalizationPatternList; 836004a3d4fSNicolas Vasilache 837004a3d4fSNicolas Vasilache template <> 838004a3d4fSNicolas Vasilache class CanonicalizationPatternList<> { 839004a3d4fSNicolas Vasilache public: 840dc4e913bSChris Lattner static void insert(RewritePatternSet &patterns) {} 841004a3d4fSNicolas Vasilache }; 842004a3d4fSNicolas Vasilache 843004a3d4fSNicolas Vasilache template <typename OpTy, typename... 
OpTypes> 844004a3d4fSNicolas Vasilache class CanonicalizationPatternList<OpTy, OpTypes...> { 845004a3d4fSNicolas Vasilache public: 846dc4e913bSChris Lattner static void insert(RewritePatternSet &patterns) { 8473a506b31SChris Lattner OpTy::getCanonicalizationPatterns(patterns, patterns.getContext()); 8483a506b31SChris Lattner CanonicalizationPatternList<OpTypes...>::insert(patterns); 849004a3d4fSNicolas Vasilache } 850004a3d4fSNicolas Vasilache }; 851004a3d4fSNicolas Vasilache } // namespace 852004a3d4fSNicolas Vasilache 853dc4e913bSChris Lattner RewritePatternSet 854004a3d4fSNicolas Vasilache mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) { 855dc4e913bSChris Lattner RewritePatternSet patterns(ctx); 8563a506b31SChris Lattner populateLinalgTilingCanonicalizationPatterns(patterns); 8572c66b6ecSNicolas Vasilache return patterns; 8582c66b6ecSNicolas Vasilache } 8592c66b6ecSNicolas Vasilache 8602c66b6ecSNicolas Vasilache void mlir::linalg::populateLinalgTilingCanonicalizationPatterns( 861dc4e913bSChris Lattner RewritePatternSet &patterns) { 8623a506b31SChris Lattner auto *ctx = patterns.getContext(); 8634c48f016SMatthias Springer affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx); 8644c48f016SMatthias Springer affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx); 8654c48f016SMatthias Springer affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx); 8664c48f016SMatthias Springer affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx); 867a54f4eaeSMogball arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx); 868a3f42594SLei Zhang 869a3f42594SLei Zhang memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx); 870a3f42594SLei Zhang memref::ViewOp::getCanonicalizationPatterns(patterns, ctx); 871a3f42594SLei Zhang 872004a3d4fSNicolas Vasilache scf::ForOp::getCanonicalizationPatterns(patterns, ctx); 873004a3d4fSNicolas Vasilache scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx); 874a3f42594SLei Zhang 875a3f42594SLei Zhang tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); 87681ca5aa4SMatthias Springer tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); 877060208b4SMatthias Springer tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); 878a0e02018SMatthias Springer tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); 879fd0c6f53SAlexander Belyaev tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); 88066e27082SRiver Riddle ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns); 881a3f42594SLei Zhang 882004a3d4fSNicolas Vasilache CanonicalizationPatternList< 883004a3d4fSNicolas Vasilache #define GET_OP_LIST 884004a3d4fSNicolas Vasilache #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 8853a506b31SChris Lattner >::insert(patterns); 886baecae83SAlexander Belyaev } 887
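// A minimal usage sketch (hypothetical, not part of this file): the
// canonicalization patterns collected above are typically applied after
// tiling, e.g.
//
//   RewritePatternSet patterns =
//       mlir::linalg::getLinalgTilingCanonicalizationPatterns(ctx);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// where `funcOp` and `ctx` stand for the enclosing function and its
// MLIRContext.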