//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == mlir::AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}
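
// Illustrative note (not from the original source): with tileSizes =
// [0, 8, 0, 2], the expression `d0 + 2 * d1` is considered tiled because `d1`
// maps to the non-zero tile size 8, whereas `d0 + d2` is not, since both of
// its positions carry a zero tile size.
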
// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!llvm::hasSingleElement(region))
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//

namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}
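
// For intuition, `isElementwise` below accepts ops of the following shape
// (illustrative sketch only, not verified IR): all loops are parallel, all
// indexing maps are projected permutations, and the body is purely scalar.
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   linalg.generic {indexing_maps = [#map, #map, #map],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
//       outs(%init : tensor<4x8xf32>) {
//   ^bb0(%x: f32, %y: f32, %out: f32):
//     %0 = arith.addf %x, %y : f32
//     linalg.yield %0 : f32
//   } -> tensor<4x8xf32>
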
bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the
  // matched LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if `padOpSliceOp`, which defines the slice used by `padOp`, is
  // rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
  // by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}
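
// Illustrative sketch of the recognized pattern (names are hypothetical, not
// from the original source): given a chain such as
//
//   %s      = tensor.extract_slice %t ...
//   %padded = tensor.pad %s ... { tensor.yield %cst } : ... to tensor<8x16xf32>
//   %r      = linalg.matmul ins(...) outs(%padded : tensor<8x16xf32>) -> ...
//   %source = tensor.extract_slice %r ...
//
// a request to pad `%source` back up to tensor<8x16xf32> with `%cst` can
// simply reuse `%r`, which is already high-padded, instead of creating a new
// tensor.pad.
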
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                          Value outputTensor,
                          ArrayRef<int64_t> transposeVector) {
  auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
  Type elementType = resultTensorType.getElementType();

  assert(isPermutationVector(transposeVector) &&
         "expect transpose vector to be a permutation");
  assert(transposeVector.size() ==
             static_cast<size_t>(resultTensorType.getRank()) &&
         "expect transpose vector size to match result tensor rank");

  // Compute the transpose and the identity indexing maps.
  SmallVector<AffineMap> indexingMaps = {
      inversePermutation(AffineMap::getPermutationMap(
          SmallVector<unsigned>(transposeVector), b.getContext())),
      AffineMap::getMultiDimIdentityMap(transposeVector.size(),
                                        b.getContext())};
  SmallVector<utils::IteratorType> iteratorTypes(transposeVector.size(),
                                                 utils::IteratorType::parallel);

  // Create a GenericOp to transpose `inputTensor` into `outputTensor`.
  auto transposeOp =
      b.create<GenericOp>(loc, resultTensorType, inputTensor, outputTensor,
                          indexingMaps, iteratorTypes);

  // Create the body of the transpose operation.
  OpBuilder::InsertionGuard g(b);
  Region &body = transposeOp.getRegion();
  Block *bodyBlock = b.createBlock(&body, /*insertPt=*/{},
                                   {elementType, elementType}, {loc, loc});
  b.create<YieldOp>(loc, bodyBlock->getArgument(0));
  return transposeOp;
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return b.create<linalg::GenericOp>(
      loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        b.create<linalg::YieldOp>(loc, args.front());
      });
}
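
// Illustrative sketch (not verified IR): for rank-2 memrefs the copy built
// above is morally
//
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   linalg.generic {indexing_maps = [#id, #id],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%from : memref<?x?xf32>) outs(%to : memref<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }
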
/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}
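
// Illustrative sketch (not verified IR): for a 2-d iteration space with
// constant steps, the affine specialization above builds a nest of the form
//
//   affine.for %i = %lb0 to %ub0 step 2 {
//     affine.for %j = %lb1 to %ub1 step 4 {
//       ... body emitted by `bodyBuilderFn` ...
//     }
//   }
//
// while the scf::ForOp specialization above it additionally threads the DPS
// init values through iter_args when the op has tensor semantics.
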
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
                                       Value nprocs, Value &lb, Value &ub,
                                       Value &step) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
  lb =
      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
}

/// Generates a loop nest consisting of scf.parallel and scf.for, depending
/// on the `iteratorTypes`. Consecutive parallel loops create a single
/// scf.parallel operation; each sequential loop creates a new scf.for
/// operation. The body of the innermost loop is populated by
/// `bodyBuilderFn` that accepts a range of induction variables for all
/// loops. `ivStorage` is used to store the partial list of induction
/// variables.
// TODO: this function can be made iterative instead. However, it
// will have at most as many recursive calls as nested loops, which rarely
// exceeds 10.
static void generateParallelLoopNest(
    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
    ArrayRef<linalg::ProcInfo> procInfo,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    SmallVectorImpl<Value> &ivStorage) {
  assert(lbs.size() == ubs.size());
  assert(lbs.size() == steps.size());
  assert(lbs.size() == iteratorTypes.size());
  assert(procInfo.empty() || (lbs.size() == procInfo.size()));

  // If there are no (more) loops to be generated, generate the body and be
  // done with it.
  if (iteratorTypes.empty()) {
    bodyBuilderFn(b, loc, ivStorage);
    return;
  }

  // If there are no outer parallel loops, generate one sequential loop and
  // recurse.
  if (!isParallelIterator(iteratorTypes.front())) {
    LoopNest singleLoop = buildLoopNest(
        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
        [&](OpBuilder &b, Location loc, ValueRange ivs) {
          ivStorage.append(ivs.begin(), ivs.end());
          generateParallelLoopNest(
              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
              iteratorTypes.drop_front(),
              procInfo.empty() ? procInfo : procInfo.drop_front(),
              bodyBuilderFn, ivStorage);
        });
    return;
  }

  unsigned nLoops = iteratorTypes.size();
  unsigned numProcessed = 0;
  DistributionMethod distributionMethod = DistributionMethod::None;
  if (procInfo.empty()) {
    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
  } else {
    distributionMethod = procInfo.front().distributionMethod;
    numProcessed =
        nLoops - procInfo
                     .drop_while([&](linalg::ProcInfo p) {
                       return p.distributionMethod == distributionMethod;
                     })
                     .size();
  }

  auto remainderProcInfo =
      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
  switch (distributionMethod) {
  case DistributionMethod::None: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::Cyclic: {
    // Generate a single parallel loop-nest operation for all outermost
    // parallel loops and recurse.
    b.create<scf::ParallelOp>(
        loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
        steps.take_front(numProcessed),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
          ivStorage.append(localIvs.begin(), localIvs.end());
          generateParallelLoopNest(
              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
              bodyBuilderFn, ivStorage);
        });
    return;
  }
  case DistributionMethod::CyclicNumProcsGeNumIters: {
    // Check (for the processed loops) that the iteration is in-bounds.
    ArithBuilder ab(b, loc);
    Value cond = ab.slt(lbs[0], ubs[0]);
    for (unsigned i = 1; i < numProcessed; ++i)
      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) {
      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
                               ubs.drop_front(numProcessed),
                               steps.drop_front(numProcessed),
                               iteratorTypes.drop_front(numProcessed),
                               remainderProcInfo, bodyBuilderFn, ivStorage);
      b.create<scf::YieldOp>(loc, ValueRange{});
    });
    return;
  }
  case DistributionMethod::CyclicNumProcsEqNumIters:
    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
    // with inner loop generation.
    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
    generateParallelLoopNest(
        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
        remainderProcInfo, bodyBuilderFn, ivStorage);
    return;
  }
}
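
// Illustrative note (not from the original source): for iterator types
// [parallel, parallel, reduction, parallel] and no distribution info, the
// recursion above produces a structure of the form
//
//   scf.parallel (%i, %j) = ... {
//     scf.for %k = ... {
//       scf.parallel (%l) = ... {
//         ... body ...
//       }
//     }
//   }
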
/// Specialization for generating a mix of parallel and sequential scf loops.
template <>
void GenerateLoopNest<scf::ParallelOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
  // This function may be passed more iterator types than ranges.
  assert(iteratorTypes.size() >= loopRanges.size() &&
         "expected iterator type for all ranges");
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected proc information for all loops when present");
  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
  unsigned numLoops = iteratorTypes.size();
  ivs.reserve(numLoops);
  lbsStorage.reserve(numLoops);
  ubsStorage.reserve(numLoops);
  stepsStorage.reserve(numLoops);

  // Get the loop lb, ub, and step.
  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);

  // Modify the lb, ub, and step based on the distribution options.
  for (const auto &it : llvm::enumerate(procInfo)) {
    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
      updateBoundsForCyclicDistribution(
          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
          ubsStorage[it.index()], stepsStorage[it.index()]);
    }
  }
  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
  generateParallelLoopNest(
      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
      [&](OpBuilder &b, Location loc, ValueRange ivs) {
        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
      },
      ivs);

  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
}

static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
                                        Value valueToTile,
                                        const SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](MemRefType) {
                        return builder.create<memref::SubViewOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}

Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                          ArrayRef<OpFoldResult> lbs,
                          ArrayRef<OpFoldResult> ubs,
                          ArrayRef<OpFoldResult> subShapeSizes,
                          bool omitPartialTileCheck) {
  SliceParameters sliceParams =
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck);
  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
}
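
// Illustrative note (names hypothetical, not verified IR): for a
// tensor<?x?xf32> operand whose first dimension is tiled by 8, the computed
// parameters typically materialize as
//
//   %tile = tensor.extract_slice %t[%iv, 0] [%sz, %d1] [1, 1]
//       : tensor<?x?xf32> to tensor<?x?xf32>
//
// where %sz is an affine.min-guarded size unless `omitPartialTileCheck` is
// set, in which case the raw tile size is used directly.
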
SliceParameters
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                       ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                       ArrayRef<OpFoldResult> subShapeSizes,
                       bool omitPartialTileCheck) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  assert(shapedType && "only shaped types can be tiled");
  ArrayRef<int64_t> shape = shapedType.getShape();
  int64_t rank = shapedType.getRank();

  // Compute offsets/sizes/strides for the tile.
  SliceParameters sliceParams;
  sliceParams.offsets.reserve(rank);
  sliceParams.sizes.reserve(rank);
  sliceParams.strides.reserve(rank);
  for (unsigned r = 0; r < rank; ++r) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
      sliceParams.strides.push_back(builder.getIndexAttr(1));
      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");

    // Tiling creates a new slice at the proper index, the slice step is 1
    // (i.e. the op does not subsample, stepping occurs in the loop).
    auto m = map.getSubMap({r});
    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
    IRRewriter rewriter(builder);
    OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
    sliceParams.offsets.push_back(offset);
    OpFoldResult closedIntSize =
        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
    // The resulting size needs to be made a half-open interval again.
    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
    OpFoldResult size =
        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: raw size: " << size << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "computeSliceParameters: new offset: " << offset << "\n");
    sliceParams.strides.push_back(builder.getIndexAttr(1));

    if (omitPartialTileCheck) {
      // We statically know that the partial/boundary tile condition is
      // unnecessary.
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
      sliceParams.sizes.push_back(size);
      continue;
    }

    // The size of the subview / extract_slice should be trimmed to avoid
    // out-of-bounds accesses, unless:
    // a. We statically know the subshape size divides the shape size evenly.
    // b. The subshape size is 1. According to the way the loops are set up,
    //    tensors with "0" dimensions would never be constructed.
    int64_t shapeSize = shape[r];
    std::optional<int64_t> sizeCst = getConstantIntValue(size);
    auto hasTileSizeOne = sizeCst && *sizeCst == 1;
    auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
                         ((shapeSize % *sizeCst) == 0);
    if (!hasTileSizeOne && !dividesEvenly) {
      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
                              << ": make sure in bound with affine.min\n");

      AffineExpr dim0, dim1, dim2;
      MLIRContext *context = builder.getContext();
      bindDims(context, dim0, dim1, dim2);

      // Get the dimension size for this dimension. We need to first calculate
      // the max index and then add one. This is important because for
      // convolution ops, the affine map of an input window dimension has the
      // form `(d0 * s0 + d1)`, where `d0`/`d1` are output/filter window
      // dimensions and `s0` is the stride. Directly using the dimension size
      // of the output/filter window dimensions would yield incorrect results.
      AffineMap minusOneMap = AffineMap::inferFromExprList(
                                  {ArrayRef<AffineExpr>{dim0 - 1}}, context)
                                  .front();
      AffineMap plusOneMap = AffineMap::inferFromExprList(
                                 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
                                 .front();
      SmallVector<OpFoldResult> maxIndices =
          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
                                                 {ub});
          }));
      OpFoldResult maxIndex =
          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
      OpFoldResult d =
          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});

      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
      AffineMap minMap = AffineMap::inferFromExprList(
                             {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
                             .front();
      size =
          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
    }
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
    sliceParams.sizes.push_back(size);
  }
  return sliceParams;
}

SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
                                             ArrayRef<OpFoldResult> ivs,
                                             ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> offsets;
  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
    LLVM_DEBUG(llvm::dbgs()
               << "computeTileOffsets: " << offsets.back() << "\n");
  }
  return offsets;
}

SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
                                           ArrayRef<OpFoldResult> tileSizes,
                                           ArrayRef<OpFoldResult> sizeBounds) {
  SmallVector<OpFoldResult> sizes;
  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
    bool isTiled = !isZeroIndex(tileSizes[idx]);
    // Before composing, we need to make the range a closed interval.
    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
    IRRewriter rewriter(b);
    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
  }
  return sizes;
}

SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
  if (op.hasPureBufferSemantics())
    return {};
  return llvm::to_vector(
      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return operands[opOperand.getOperandNumber()].getType();
      }));
}
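
// Illustrative note (names hypothetical, not from the original source): when a
// tiled output was produced through
//   %slice = tensor.extract_slice %init[%i, %j] [%sx, %sy] [1, 1] ...
// `insertSlicesBack` below mirrors those offsets/sizes/strides to write the
// tile result back:
//   %full = tensor.insert_slice %tiledResult into %init[%i, %j] [%sx, %sy]
//           [1, 1] ...
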
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                    LinalgOp op, ValueRange operands,
                                    ValueRange results) {
  if (op.hasPureBufferSemantics())
    return {};
  SmallVector<Value> tensorResults;
  tensorResults.reserve(results.size());
  // Insert an insert_slice for each output tensor.
  unsigned resultIdx = 0;
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    // TODO: use an interface/adaptor to avoid leaking position in
    // `tiledOperands`.
    Value outputTensor = operands[opOperand.getOperandNumber()];
    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      Value inserted = builder.create<tensor::InsertSliceOp>(
          loc, sliceOp.getSource().getType(), results[resultIdx],
          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
      tensorResults.push_back(inserted);
    } else {
      tensorResults.push_back(results[resultIdx]);
    }
    ++resultIdx;
  }
  return tensorResults;
}

SmallVector<std::optional<SliceParameters>>
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                          ArrayRef<OpFoldResult> tileSizes,
                          ArrayRef<OpFoldResult> sizeBounds,
                          bool omitPartialTileCheck) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");

  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> subShapeSizes =
      computeTileSizes(builder, loc, tileSizes, sizeBounds);

  assert(static_cast<int64_t>(valuesToTile.size()) <=
             linalgOp->getNumOperands() &&
         "more values to tile than operands");
  SmallVector<std::optional<SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
    // Use `opOperand` as is if it is not tiled and not an output tensor.
    // Having an extract/insert slice pair for all output tensors simplifies
    // follow-up transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
    Type operandType = opOperand.get().getType();
    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
                                      linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");

    allSliceParams.push_back(computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        omitPartialTileCheck));
  }

  return allSliceParams;
}

SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
                                   LinalgOp linalgOp, ValueRange valuesToTile,
                                   ArrayRef<OpFoldResult> ivs,
                                   ArrayRef<OpFoldResult> tileSizes,
                                   ArrayRef<OpFoldResult> sizeBounds,
                                   bool omitPartialTileCheck) {
  SmallVector<std::optional<SliceParameters>> allSliceParameter =
      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
                                tileSizes, sizeBounds, omitPartialTileCheck);
  SmallVector<Value> tiledShapes;
  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
    Value valueToTile = std::get<0>(item);
    std::optional<SliceParameters> sliceParams = std::get<1>(item);
    tiledShapes.push_back(
        sliceParams.has_value()
            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
                  ->getResult(0)
            : valueToTile);
  }
  return tiledShapes;
}

void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  offsetIndices(rewriter, linalgOp, offsets);
}

void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
                   ArrayRef<OpFoldResult> offsets) {
  if (!linalgOp.hasIndexSemantics())
    return;

  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = makeComposedFoldedAffineApply(
        b, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
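
// Illustrative note (not from the original source): with offsets = [%off], a
// `%idx = linalg.index 0` inside the body is rewritten so that its other uses
// see roughly
//   %new = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%idx, %off)
// possibly folded to a constant when both inputs are statically known.
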
/// Get the reassociation maps to fold the result of an extract_slice (or the
/// source of an insert_slice) operation with the given offsets and sizes to
/// its rank-reduced version. This is only done for the cases where the size
/// is 1 and the offset is 0. Strictly speaking the offset 0 is not required
/// in general, but non-zero offsets are not handled by the SPIR-V backend at
/// this point (and potentially cannot be handled).
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices curr;
  for (const auto &it : llvm::enumerate(mixedSizes)) {
    auto dim = it.index();
    auto size = it.value();
    curr.push_back(dim);
    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
      continue;
    reassociation.emplace_back(ReassociationIndices{});
    std::swap(reassociation.back(), curr);
  }
  // When the reassociation is not empty, fold the remaining unit dimensions
  // into the last dimension. If the reassociation so far is empty, leave it
  // empty; this will fold everything to a rank-0 tensor.
  if (!curr.empty() && !reassociation.empty())
    reassociation.back().append(curr.begin(), curr.end());
  return reassociation;
}

} // namespace linalg
} // namespace mlir