1 //===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" 18 #include "mlir/Dialect/SCF/IR/SCF.h" 19 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 20 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/Builders.h" 24 #include "mlir/IR/BuiltinAttributes.h" 25 #include "mlir/IR/IRMapping.h" 26 #include "mlir/IR/MLIRContext.h" 27 #include "mlir/IR/OpDefinition.h" 28 #include "mlir/IR/Visitors.h" 29 #include "mlir/Support/LLVM.h" 30 #include "llvm/ADT/STLExtras.h" 31 #include "llvm/ADT/SmallVector.h" 32 #include "llvm/Support/Debug.h" 33 34 using namespace mlir; 35 using namespace mlir::gpu; 36 using namespace mlir::transform; 37 using namespace mlir::transform::gpu; 38 39 #define DEBUG_TYPE "gpu-transforms" 40 41 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 42 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 43 44 namespace { 45 46 /// Return a flattened thread id for the workgroup with given sizes. 47 static Value buildLinearThreadId(RewriterBase &rewriter, Location loc, 48 ArrayRef<OpFoldResult> blockDimsOfr) { 49 LLVM_DEBUG(llvm::interleaveComma( 50 blockDimsOfr, 51 DBGS() << "----buildLinearThreadId with blockDimsOfr: "); 52 llvm::dbgs() << "\n"); 53 assert(blockDimsOfr.size() == 3 && "expected 3 workgroup sizes"); 54 AffineExpr tx, ty, tz, BDX, BDY; 55 bindDims(rewriter.getContext(), tx, ty, tz); 56 bindSymbols(rewriter.getContext(), BDX, BDY); 57 IndexType indexType = rewriter.getIndexType(); 58 SmallVector<OpFoldResult> threadsAndWorkGroups{ 59 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x).getResult(), 60 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y).getResult(), 61 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z).getResult()}; 62 threadsAndWorkGroups.push_back(blockDimsOfr[0]); 63 threadsAndWorkGroups.push_back(blockDimsOfr[1]); 64 OpFoldResult ofr = makeComposedFoldedAffineApply( 65 rewriter, loc, tx + ty * BDX + tz * BDX * BDY, threadsAndWorkGroups); 66 return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); 67 } 68 69 /// Builder for gpu::BlockIdOps used in mapping scf.forall to blocks. 70 /// The `idBuilder` method returns 3-D values used for indexing rewrites as well 71 /// as 3-D sizes for predicate generation. 72 struct GpuBlockIdBuilder : public GpuIdBuilder { 73 74 GpuBlockIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims, 75 ArrayRef<int64_t> mappingSizes) 76 : GpuIdBuilder(blockDims, mappingSizes) { 77 mappingAttributes = {GPUBlockMappingAttr::get(ctx, Blocks::DimX), 78 GPUBlockMappingAttr::get(ctx, Blocks::DimY), 79 GPUBlockMappingAttr::get(ctx, Blocks::DimZ)}, 80 idBuilder = [](RewriterBase &rewriter, Location loc, 81 ArrayRef<int64_t> forallMappingSizes) { 82 IndexType indexType = rewriter.getIndexType(); 83 SmallVector<Value> ids{ 84 rewriter.create<BlockIdOp>(loc, indexType, Dimension::x), 85 rewriter.create<BlockIdOp>(loc, indexType, Dimension::y), 86 rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)}; 87 // Return 3-D ids for indexing rewrite and 3-D sizes and ids for 88 // predicate generation. 89 return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, 90 ids}; 91 }; 92 } 93 }; 94 95 /// Builder for gpu::ThreadIdOp used in mapping scf.forall to thread ids without 96 /// any reindexing. 97 /// The `idBuilder` method returns 3-D values used for indexing rewrites as well 98 /// as 3-D sizes for predicate generation. 99 struct GpuThreadIdBuilder : public GpuIdBuilder { 100 GpuThreadIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims, 101 ArrayRef<int64_t> mappingSizes) 102 : GpuIdBuilder(blockDims, mappingSizes) { 103 mappingAttributes = {GPUThreadMappingAttr::get(ctx, Threads::DimX), 104 GPUThreadMappingAttr::get(ctx, Threads::DimY), 105 GPUThreadMappingAttr::get(ctx, Threads::DimZ)}; 106 idBuilder = [](RewriterBase &rewriter, Location loc, 107 ArrayRef<int64_t> forallMappingSizes) { 108 IndexType indexType = rewriter.getIndexType(); 109 SmallVector<Value> ids{ 110 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x), 111 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y), 112 rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)}; 113 // Return 3-D ids for indexing rewrite and 3-D sizes and ids for 114 // predicate generation. 115 return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, 116 ids}; 117 }; 118 } 119 }; 120 121 /// Builder for warp ids used in mapping scf.forall to warps. 122 /// This builder requires a specification of the number of warps along each 123 /// dimension to more finely control mapping to warps as well a predication than 124 /// by solely analyzing the IR. 125 /// The `idBuilder` method returns 3-D values used for indexing rewrites as well 126 /// as 3-D sizes for predicate generation. 127 struct GpuWarpIdBuilder : public GpuIdBuilder { 128 GpuWarpIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims, 129 ArrayRef<int64_t> mappingSizes) 130 : GpuIdBuilder(blockDims, mappingSizes) { 131 mappingAttributes = {GPUWarpMappingAttr::get(ctx, Warps::DimX), 132 GPUWarpMappingAttr::get(ctx, Warps::DimY), 133 GPUWarpMappingAttr::get(ctx, Warps::DimZ)}; 134 idBuilder = [this](RewriterBase &rewriter, Location loc, 135 ArrayRef<int64_t> forallMappingSizes) { 136 // Build the linear warp id and decompose it in the basis of 137 // `forallMappingSizes`. 138 Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr); 139 AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); 140 OpFoldResult warpIdOfr = makeComposedFoldedAffineApply( 141 rewriter, loc, d0.floorDiv(kWarpSize), {linearId}); 142 Value warpId = getValueOrCreateConstantIndexOp(rewriter, loc, warpIdOfr); 143 SmallVector<int64_t> reverseBasisSizes( 144 llvm::reverse(this->availableMappingSizes)); 145 SmallVector<int64_t> strides = computeStrides(reverseBasisSizes); 146 SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides); 147 SmallVector<Value> ids; 148 for (AffineExpr e : delinearizingExprs) 149 ids.push_back(makeComposedAffineApply(rewriter, loc, e, warpId)); 150 151 // clang-format off 152 LDBG("----linearId: " << linearId); 153 LDBG("----warpId: " << warpId); 154 LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes, 155 DBGS() << "--delinearization basis: "); 156 llvm::dbgs() << "\n"; 157 llvm::interleaveComma(strides, 158 DBGS() << "--delinearization strides: "); 159 llvm::dbgs() << "\n"; 160 llvm::interleaveComma(delinearizingExprs, 161 DBGS() << "--delinearization exprs: "); 162 llvm::dbgs() << "\n"; 163 llvm::interleaveComma(ids, DBGS() << "--ids: "); 164 llvm::dbgs() << "\n";); 165 // clang-format on 166 167 // Return 3-D ids for indexing rewrite and 3-D sizes and ids for 168 // predicate generation. 169 return IdBuilderResult{ids, SmallVector<int64_t>{forallMappingSizes}, 170 ids}; 171 }; 172 } 173 174 /// Static specification of the warp size. 175 /// In the future this may be configured by the transformation. 176 static constexpr int64_t kWarpSize = 32; 177 }; 178 179 /// Builder for linear ids used in mapping scf.forall to reindexed threads. 180 /// The `idBuilder` method returns 3-D values used for indexing rewrites as well 181 /// as 1-D sizes for predicate generation. 182 struct GpuLinearIdBuilder : public GpuIdBuilder { 183 GpuLinearIdBuilder(MLIRContext *ctx, ArrayRef<OpFoldResult> blockDims, 184 ArrayRef<int64_t> mappingSizes) 185 : GpuIdBuilder(blockDims, mappingSizes) { 186 mappingAttributes = {GPULinearIdMappingAttr::get(ctx, LinearId::DimX), 187 GPULinearIdMappingAttr::get(ctx, LinearId::DimY), 188 GPULinearIdMappingAttr::get(ctx, LinearId::DimZ)}; 189 idBuilder = [this](RewriterBase &rewriter, Location loc, 190 ArrayRef<int64_t> forallMappingSizes) { 191 // Build the linear thread id and decompose it in the basis of 192 // `forallMappingSizes`. 193 Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr); 194 SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes)); 195 SmallVector<int64_t> strides = computeStrides(reverseBasisSizes); 196 AffineExpr d0; 197 bindDims(rewriter.getContext(), d0); 198 SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides); 199 SmallVector<Value> ids; 200 for (AffineExpr e : delinearizingExprs) 201 ids.push_back(makeComposedAffineApply(rewriter, loc, e, linearId)); 202 203 // clang-format off 204 LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes, 205 DBGS() << "--delinearization basis: "); 206 llvm::dbgs() << "\n"; 207 llvm::interleaveComma(strides, 208 DBGS() << "--delinearization strides: "); 209 llvm::dbgs() << "\n"; 210 llvm::interleaveComma(delinearizingExprs, 211 DBGS() << "--delinearization exprs: "); 212 llvm::dbgs() << "\n"; 213 llvm::interleaveComma(ids, DBGS() << "--ids: "); 214 llvm::dbgs() << "\n";); 215 // clang-format on 216 217 // Compute and return the 1-D actual mapping size spanned by the linearId, 218 // it will be used to predicate against the linearized total number of 219 // threads. 220 int64_t actualMappingSize = 1; 221 for (int64_t s : forallMappingSizes) 222 actualMappingSize *= s; 223 224 // Return 3-D ids for indexing rewrite and 1-D size and id for 225 // predicate generation. 226 return IdBuilderResult{ids, SmallVector<int64_t>{actualMappingSize}, 227 SmallVector<Value>{linearId}}; 228 }; 229 } 230 }; 231 232 } // namespace 233 234 static DiagnosedSilenceableFailure 235 definiteFailureHelper(std::optional<TransformOpInterface> transformOp, 236 Operation *target, const Twine &message) { 237 if (transformOp.has_value()) 238 return transformOp->emitDefiniteFailure() << message; 239 return emitDefiniteFailure(target, message); 240 } 241 242 /// Check if given mapping attributes are one of the desired attributes 243 static DiagnosedSilenceableFailure 244 checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp, 245 scf::ForallOp forallOp) { 246 if (!forallOp.getMapping().has_value()) 247 return definiteFailureHelper(transformOp, forallOp, 248 "mapping must be present"); 249 250 bool hasBlockMapping = 251 llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { 252 return attr.isa<GPUBlockMappingAttr>(); 253 }); 254 bool hasThreadMapping = 255 llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { 256 return attr.isa<GPUThreadMappingAttr>(); 257 }); 258 bool hasWarpMapping = 259 llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { 260 return attr.isa<GPUWarpMappingAttr>(); 261 }); 262 bool hasLinearMapping = 263 llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { 264 return attr.isa<GPULinearIdMappingAttr>(); 265 }); 266 int64_t countMappingTypes = 0; 267 countMappingTypes += hasBlockMapping ? 1 : 0; 268 countMappingTypes += hasThreadMapping ? 1 : 0; 269 countMappingTypes += hasWarpMapping ? 1 : 0; 270 countMappingTypes += hasLinearMapping ? 1 : 0; 271 if (countMappingTypes > 1) { 272 return definiteFailureHelper( 273 transformOp, forallOp, 274 "cannot mix different mapping types, use nesting"); 275 } 276 277 DenseSet<Attribute> seen; 278 for (Attribute map : forallOp.getMapping()->getValue()) { 279 if (seen.contains(map)) { 280 return definiteFailureHelper( 281 transformOp, forallOp, 282 "duplicated attribute, cannot map different loops " 283 "to the same processor"); 284 } 285 seen.insert(map); 286 } 287 288 return DiagnosedSilenceableFailure::success(); 289 } 290 291 static DiagnosedSilenceableFailure 292 verifyGpuMapping(std::optional<TransformOpInterface> transformOp, 293 scf::ForallOp forallOp) { 294 // Check the types of the mapping attributes match. 295 DiagnosedSilenceableFailure typeRes = 296 checkMappingAttributeTypes(transformOp, forallOp); 297 if (!typeRes.succeeded()) 298 return typeRes; 299 300 // Perform other non-types verifications. 301 if (!forallOp.isNormalized()) 302 return definiteFailureHelper(transformOp, forallOp, 303 "unsupported non-normalized loops"); 304 if (forallOp.getNumResults() > 0) 305 return definiteFailureHelper(transformOp, forallOp, 306 "only bufferized scf.forall can be mapped"); 307 if (forallOp.getRank() > 3) 308 return definiteFailureHelper(transformOp, forallOp, 309 "scf.forall with rank > 3 does not lower"); 310 if (llvm::any_of(forallOp.getMixedUpperBound(), [&](OpFoldResult ofr) { 311 return !getConstantIntValue(ofr).has_value(); 312 })) { 313 return definiteFailureHelper(transformOp, forallOp, 314 "unsupported dynamic sizes in forall op"); 315 } 316 return DiagnosedSilenceableFailure::success(); 317 } 318 319 /// Determines if the size of the kernel configuration is supported by the 320 /// GPU architecture being used. It presently makes use of CUDA limitations, 321 /// however that aspect may be enhanced for other GPUs. 322 static DiagnosedSilenceableFailure checkGpuLimits( 323 TransformOpInterface transformOp, std::optional<int64_t> gridDimX, 324 std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ, 325 std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY, 326 std::optional<int64_t> blockDimZ) { 327 328 static constexpr int maxTotalBlockdim = 1024; 329 static constexpr int maxBlockdimx = 1024; 330 static constexpr int maxBlockdimy = 1024; 331 static constexpr int maxBlockdimz = 64; 332 static constexpr int maxTotalGriddim = 2147483647; 333 static constexpr int maxGriddimx = 2147483647; 334 static constexpr int maxGriddimy = 65535; 335 static constexpr int maxGriddimz = 65535; 336 337 if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) > 338 maxTotalBlockdim || 339 (gridDimX.value_or(1) * gridDimY.value_or(1) * gridDimZ.value_or(1)) > 340 maxTotalGriddim || 341 blockDimX.value_or(1) > maxBlockdimx || 342 blockDimY.value_or(1) > maxBlockdimy || 343 blockDimZ.value_or(1) > maxBlockdimz || 344 gridDimY.value_or(1) > maxGriddimy || 345 gridDimZ.value_or(1) > maxGriddimz || 346 gridDimX.value_or(1) > maxGriddimx) { 347 return transformOp.emitSilenceableError() 348 << "Trying to launch a GPU kernel with grid_dims = (" 349 << gridDimX.value_or(1) << ", " << gridDimY.value_or(1) << ", " 350 << gridDimZ.value_or(1) << ") block_dims = (" 351 << blockDimX.value_or(1) << ", " << blockDimY.value_or(1) << ", " 352 << blockDimZ.value_or(1) << "). It is larger than the limits."; 353 } 354 return DiagnosedSilenceableFailure::success(); 355 } 356 357 /// Creates an empty-body gpu::LaunchOp using the provided kernel settings 358 /// and put a terminator within. 359 static DiagnosedSilenceableFailure 360 createGpuLaunch(RewriterBase &rewriter, Location loc, 361 TransformOpInterface transformOp, LaunchOp &launchOp, 362 std::optional<int64_t> gridDimX = std::nullopt, 363 std::optional<int64_t> gridDimY = std::nullopt, 364 std::optional<int64_t> gridDimZ = std::nullopt, 365 std::optional<int64_t> blockDimX = std::nullopt, 366 std::optional<int64_t> blockDimY = std::nullopt, 367 std::optional<int64_t> blockDimZ = std::nullopt) { 368 DiagnosedSilenceableFailure diag = 369 checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, 370 blockDimY, blockDimZ); 371 if (!diag.succeeded()) 372 return diag; 373 374 auto createConst = [&](int dim) { 375 return rewriter.create<arith::ConstantIndexOp>(loc, dim); 376 }; 377 OpBuilder::InsertionGuard guard(rewriter); 378 Value one = createConst(1); 379 Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one; 380 Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one; 381 Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one; 382 Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one; 383 Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one; 384 Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one; 385 launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ, 386 blkSizeX, blkSizeY, blkSizeZ); 387 rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); 388 rewriter.create<TerminatorOp>(loc); 389 return DiagnosedSilenceableFailure::success(); 390 } 391 392 /// Alter kernel configuration of the given kernel. 393 static DiagnosedSilenceableFailure 394 alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, 395 TransformOpInterface transformOp, 396 std::optional<int64_t> gridDimX = std::nullopt, 397 std::optional<int64_t> gridDimY = std::nullopt, 398 std::optional<int64_t> gridDimZ = std::nullopt, 399 std::optional<int64_t> blockDimX = std::nullopt, 400 std::optional<int64_t> blockDimY = std::nullopt, 401 std::optional<int64_t> blockDimZ = std::nullopt) { 402 DiagnosedSilenceableFailure diag = 403 checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX, 404 blockDimY, blockDimZ); 405 if (!diag.succeeded()) 406 return diag; 407 408 KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues(); 409 OpBuilder::InsertionGuard guard(rewriter); 410 rewriter.setInsertionPointAfterValue(currentBlockdim.x); 411 auto createConstValue = [&](int dim) { 412 return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(), 413 dim); 414 }; 415 416 if (gridDimX.has_value()) 417 gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value())); 418 if (gridDimY.has_value()) 419 gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value())); 420 if (gridDimZ.has_value()) 421 gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value())); 422 if (blockDimX.has_value()) 423 gpuLaunch.getBlockSizeXMutable().assign( 424 createConstValue(blockDimX.value())); 425 if (blockDimY.has_value()) 426 gpuLaunch.getBlockSizeYMutable().assign( 427 createConstValue(blockDimY.value())); 428 if (blockDimZ.has_value()) 429 gpuLaunch.getBlockSizeZMutable().assign( 430 createConstValue(blockDimZ.value())); 431 return DiagnosedSilenceableFailure::success(); 432 } 433 434 /// Struct to return the result of the rewrite of a forall operation. 435 struct ForallRewriteResult { 436 SmallVector<int64_t> mappingSizes; 437 SmallVector<Value> mappingIds; 438 }; 439 440 /// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR. 441 template <typename OpTy, typename OperationOrBlock> 442 static void 443 replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, 444 OperationOrBlock *parent, Value replacement, 445 ArrayRef<int64_t> availableMappingSizes) { 446 parent->walk([&](OpTy idOp) { 447 if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1) 448 rewriter.replaceAllUsesWith(idOp.getResult(), replacement); 449 }); 450 } 451 452 static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( 453 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, 454 scf::ForallOp forallOp, ForallRewriteResult &result, 455 ArrayRef<int64_t> availableMappingSizes, const GpuIdBuilder &gpuIdBuilder) { 456 LDBG("--start rewriteOneForallCommonImpl"); 457 458 // Step 0. GPU-specific verifications. There is no better place to anchor 459 // those right now: the ForallOp is target-independent and the transform 460 // op does not apply to individual ForallOp. 461 DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp); 462 if (!diag.succeeded()) 463 return diag; 464 465 // Step 1. Complete the mapping to a full mapping (with 1s) if necessary. 466 SmallVector<int64_t> tmpMappingSizes = llvm::to_vector( 467 llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) { 468 auto maybeStaticValue = getConstantIntValue(ofr); 469 assert(maybeStaticValue && "expected static value"); 470 return maybeStaticValue.value(); 471 })); 472 SmallVector<Attribute> forallMappingAttrs = 473 llvm::to_vector(forallOp.getMapping()->getValue()); 474 for (auto attr : gpuIdBuilder.mappingAttributes) { 475 if (llvm::is_contained(forallMappingAttrs, attr)) 476 continue; 477 forallMappingAttrs.push_back(attr); 478 tmpMappingSizes.push_back(1); 479 } 480 LLVM_DEBUG( 481 llvm::interleaveComma( 482 tmpMappingSizes, 483 DBGS() << "----tmpMappingSizes extracted from scf.forall op: "); 484 llvm::dbgs() << "\n"); 485 486 // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. 487 auto comparator = [&](DeviceMappingAttrInterface a, 488 DeviceMappingAttrInterface b) -> bool { 489 return a.getMappingId() < b.getMappingId(); 490 }; 491 SmallVector<int64_t> forallMappingSizes = 492 getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator); 493 LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes, 494 DBGS() << "----forallMappingSizes: "); 495 llvm::dbgs() << "\n"; llvm::interleaveComma( 496 forallMappingAttrs, DBGS() << "----mappingAttrs: "); 497 llvm::dbgs() << "\n"); 498 499 // Step 3. Generate the mappingIdOps using the provided generator. 500 Location loc = forallOp.getLoc(); 501 OpBuilder::InsertionGuard guard(rewriter); 502 rewriter.setInsertionPoint(forallOp); 503 IdBuilderResult builderResult = 504 gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes); 505 506 // Step 4. Map the induction variables to the mappingIdOps, this may involve a 507 // permutation. 508 SmallVector<Value> mappingIdOps = builderResult.mappingIdOps; 509 IRMapping bvm; 510 for (auto [iv, dim] : 511 llvm::zip_equal(forallOp.getInductionVars(), 512 ArrayRef<Attribute>{forallMappingAttrs}.take_front( 513 forallOp.getInductionVars().size()))) { 514 Value peIdOp = mappingIdOps[static_cast<int64_t>( 515 dim.cast<DeviceMappingAttrInterface>().getMappingId())]; 516 bvm.map(iv, peIdOp); 517 } 518 519 // Step 5. If the availableMappingSizes are already known, create conditionals 520 // to predicate the region. Otherwise, the current forall determines the 521 // availableMappingSizes and no predication occurs. 522 Value predicate; 523 if (!availableMappingSizes.empty()) { 524 SmallVector<int64_t> predicateMappingSizes = 525 builderResult.predicateMappingSizes; 526 SmallVector<Value> predicateIdOps = builderResult.predicateIdOps; 527 // clang-format off 528 LLVM_DEBUG( 529 llvm::interleaveComma( 530 predicateMappingSizes, DBGS() << "----predicateMappingSizes: "); 531 llvm::dbgs() << "\n"; 532 llvm::interleaveComma( 533 availableMappingSizes, DBGS() << "----availableMappingSizes: "); 534 llvm::dbgs() << "\n"; 535 llvm::interleaveComma(predicateIdOps, DBGS() << "----predicateIdOps: "); 536 llvm::dbgs() << "\n"); 537 // clang-format on 538 for (auto [id, mappingSize, availableMappingSize] : llvm::zip_equal( 539 predicateIdOps, predicateMappingSizes, availableMappingSizes)) { 540 if (mappingSize > availableMappingSize) { 541 return definiteFailureHelper( 542 transformOp, forallOp, 543 "Trying to map to fewer GPU threads than loop iterations but " 544 "overprovisioning is not yet supported. " 545 "Try additional tiling of the before mapping or map to more " 546 "threads."); 547 } 548 if (mappingSize == availableMappingSize) 549 continue; 550 Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize); 551 Value tmpPredicate = rewriter.create<arith::CmpIOp>( 552 loc, arith::CmpIPredicate::ult, id, idx); 553 LDBG("----predicate: " << tmpPredicate); 554 predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate, 555 tmpPredicate) 556 : tmpPredicate; 557 } 558 } 559 560 // Step 6. Move the body of forallOp. 561 // Erase the terminator first, it will not be used. 562 rewriter.eraseOp(forallOp.getTerminator()); 563 Block *targetBlock; 564 Block::iterator insertionPoint; 565 if (predicate) { 566 // Step 6.a. If predicated, move at the beginning. 567 auto ifOp = rewriter.create<scf::IfOp>(loc, predicate, 568 /*withElseRegion=*/false); 569 targetBlock = ifOp.thenBlock(); 570 insertionPoint = ifOp.thenBlock()->begin(); 571 } else { 572 // Step 6.b. Otherwise, move inline just at the rewriter insertion 573 // point. 574 targetBlock = forallOp->getBlock(); 575 insertionPoint = rewriter.getInsertionPoint(); 576 } 577 Block &sourceBlock = forallOp.getRegion().front(); 578 targetBlock->getOperations().splice(insertionPoint, 579 sourceBlock.getOperations()); 580 581 // Step 7. RAUW indices. 582 for (Value loopIndex : forallOp.getInductionVars()) { 583 Value threadIdx = bvm.lookup(loopIndex); 584 rewriter.replaceAllUsesWith(loopIndex, threadIdx); 585 } 586 587 // Step 8. Erase old op. 588 rewriter.eraseOp(forallOp); 589 590 result = ForallRewriteResult{forallMappingSizes, mappingIdOps}; 591 return DiagnosedSilenceableFailure::success(); 592 } 593 594 //===----------------------------------------------------------------------===// 595 // MapForallToBlocks 596 //===----------------------------------------------------------------------===// 597 598 DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( 599 RewriterBase &rewriter, TransformOpInterface transformOp, 600 scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims, 601 const GpuIdBuilder &gpuIdBuilder) { 602 LDBG("Start mapForallToBlocksImpl"); 603 604 Location loc = forallOp.getLoc(); 605 Block *parentBlock = forallOp->getBlock(); 606 Value zero; 607 { 608 // Create an early zero index value for replacements and immediately reset 609 // the insertion point. 610 OpBuilder::InsertionGuard guard(rewriter); 611 rewriter.setInsertionPointToStart(parentBlock); 612 zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 613 } 614 615 SmallVector<int64_t> anyAvailableMappingSizes; 616 ForallRewriteResult rewriteResult; 617 // Pass an empty anyAvailableMappingSizes. 618 DiagnosedSilenceableFailure diag = 619 rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult, 620 anyAvailableMappingSizes, gpuIdBuilder); 621 622 // Return if anything goes wrong, use silenceable failure as a match failure. 623 if (!diag.succeeded()) 624 return diag; 625 626 // Set the gridDims that act as a return. 627 gridDims = rewriteResult.mappingSizes; 628 629 // Replace ids of dimensions known to be 1 by 0 to simplify the IR. 630 // Here, the result of mapping determines the available mapping sizes. 631 replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero, 632 gridDims); 633 634 return DiagnosedSilenceableFailure::success(); 635 } 636 637 DiagnosedSilenceableFailure 638 mlir::transform::gpu::findTopLevelForallOp(Operation *target, 639 scf::ForallOp &topLevelForallOp, 640 TransformOpInterface transformOp) { 641 auto walkResult = target->walk([&](scf::ForallOp forallOp) { 642 if (forallOp->getParentOfType<scf::ForallOp>()) 643 return WalkResult::advance(); 644 if (topLevelForallOp) 645 // TODO: Handle multiple forall if they are independent. 646 return WalkResult::interrupt(); 647 topLevelForallOp = forallOp; 648 return WalkResult::advance(); 649 }); 650 651 if (walkResult.wasInterrupted()) 652 return transformOp.emitSilenceableError() 653 << "could not find a unique topLevel scf.forall"; 654 return DiagnosedSilenceableFailure::success(); 655 } 656 657 DiagnosedSilenceableFailure 658 transform::MapForallToBlocks::applyToOne(Operation *target, 659 ApplyToEachResultList &results, 660 transform::TransformState &state) { 661 LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); 662 IRRewriter rewriter(getContext()); 663 auto transformOp = cast<TransformOpInterface>(getOperation()); 664 665 if (!getGenerateGpuLaunch() && !gpuLaunch) { 666 DiagnosedSilenceableFailure diag = 667 emitSilenceableError() 668 << "Given target is not gpu.launch, set `generate_gpu_launch` " 669 "attribute"; 670 diag.attachNote(target->getLoc()) << "when applied to this payload op"; 671 return diag; 672 } 673 674 scf::ForallOp topLevelForallOp; 675 DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp( 676 target, topLevelForallOp, transformOp); 677 if (!diag.succeeded()) { 678 diag.attachNote(target->getLoc()) << "when applied to this payload op"; 679 return diag; 680 } 681 682 SmallVector<int64_t> gridDims{getGridDims()}; 683 if (!getGenerateGpuLaunch() && gridDims.size() != 3) 684 return transformOp.emitDefiniteFailure("transform require size-3 mapping"); 685 686 OpBuilder::InsertionGuard guard(rewriter); 687 rewriter.setInsertionPoint(topLevelForallOp); 688 689 // Generate gpu launch here and move the forall inside 690 if (getGenerateGpuLaunch()) { 691 DiagnosedSilenceableFailure diag = 692 createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch); 693 if (!diag.succeeded()) { 694 return diag; 695 } 696 rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); 697 Operation *newForallOp = rewriter.clone(*topLevelForallOp); 698 rewriter.eraseOp(topLevelForallOp); 699 topLevelForallOp = cast<scf::ForallOp>(newForallOp); 700 } 701 702 GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), {}, {}); 703 diag = mlir::transform::gpu::mapForallToBlocksImpl( 704 rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder); 705 if (!diag.succeeded()) 706 return diag; 707 708 // Set the GPU launch configuration for the grid dims late, this is subject to 709 // IR inspection. 710 diag = alterGpuLaunch(rewriter, gpuLaunch, 711 cast<TransformOpInterface>(getOperation()), gridDims[0], 712 gridDims[1], gridDims[2]); 713 714 results.push_back(gpuLaunch); 715 return diag; 716 } 717 718 //===----------------------------------------------------------------------===// 719 // MapNestedForallToThreads 720 //===----------------------------------------------------------------------===// 721 722 DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl( 723 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, 724 scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes, 725 bool syncAfterDistribute, const GpuIdBuilder &gpuIdBuilder) { 726 // Ignore cases with different attributes than this builder supports. 727 for (Attribute map : forallOp.getMapping()->getValue()) { 728 if (!llvm::is_contained(gpuIdBuilder.mappingAttributes, map)) { 729 LDBG("--skip " << map); 730 LLVM_DEBUG(llvm::interleaveComma(gpuIdBuilder.mappingAttributes, 731 DBGS() << "----not in: "); 732 llvm::dbgs() << "\n";); 733 return emitSilenceableFailure(forallOp); 734 } 735 } 736 737 Location loc = forallOp.getLoc(); 738 OpBuilder::InsertionGuard g(rewriter); 739 // Insert after to allow for syncthreads after `forall` is erased. 740 rewriter.setInsertionPointAfter(forallOp); 741 ForallRewriteResult rewriteResult; 742 DiagnosedSilenceableFailure diag = 743 rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, rewriteResult, 744 availableMappingSizes, gpuIdBuilder); 745 746 // Return if anything goes wrong, use silenceable failure as a match failure. 747 if (!diag.succeeded()) 748 return diag; 749 750 // Add a syncthreads if needed. TODO: warpsync 751 if (syncAfterDistribute) 752 rewriter.create<BarrierOp>(loc); 753 754 return DiagnosedSilenceableFailure::success(); 755 } 756 757 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( 758 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, 759 Operation *target, ArrayRef<int64_t> blockDims, ArrayRef<int64_t> warpDims, 760 bool syncAfterDistribute) { 761 LDBG("Start mapNestedForallToThreadsImpl"); 762 MLIRContext *ctx = rewriter.getContext(); 763 SmallVector<OpFoldResult> blockDimsOfr = 764 getAsIndexOpFoldResult(ctx, blockDims); 765 766 if (blockDims.size() != 3) 767 return definiteFailureHelper(transformOp, target, 768 "requires size-3 thread mapping"); 769 if (!warpDims.empty()) { 770 if (warpDims.size() != 3) 771 return definiteFailureHelper(transformOp, target, 772 "requires empty or size-3 warp mapping"); 773 } 774 775 // Create an early zero index value for replacements. 776 Location loc = target->getLoc(); 777 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 778 DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); 779 WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) { 780 //===--------------------------------------------------------------------===// 781 // Mapping to warp ids. 782 //===--------------------------------------------------------------------===// 783 if (!warpDims.empty()) { 784 LLVM_DEBUG( 785 llvm::interleaveComma( 786 warpDims, DBGS() << "+mapNestedForallToThreadsImpl warpDims: "); 787 llvm::dbgs() << "\n"); 788 LLVM_DEBUG(llvm::interleaveComma( 789 blockDimsOfr, DBGS() << "--warpDims with blockDimsOfr: "); 790 llvm::dbgs() << "\n"); 791 GpuWarpIdBuilder gpuWarpIdBuilder(ctx, blockDimsOfr, warpDims); 792 diag = mlir::transform::gpu::mapOneForallToThreadsImpl( 793 rewriter, transformOp, forallOp, warpDims, syncAfterDistribute, 794 gpuWarpIdBuilder); 795 // Use silenceable failure to encode "failure to match" and pass 796 // through. 797 if (diag.isDefiniteFailure()) 798 return WalkResult::interrupt(); 799 if (diag.succeeded()) 800 return WalkResult::skip(); 801 } 802 803 //===--------------------------------------------------------------------===// 804 // Mapping to linear ids. 805 //===--------------------------------------------------------------------===// 806 LDBG("+mapNestedForallToThreadsImpl linearDims"); 807 LLVM_DEBUG(llvm::interleaveComma( 808 blockDimsOfr, DBGS() << "--linearDims with blockDimsOfr: "); 809 llvm::dbgs() << "\n"); 810 int64_t numThreads = 1; 811 for (int64_t b : blockDims) 812 numThreads *= b; 813 GpuLinearIdBuilder gpuLinearIdBuilder(ctx, blockDimsOfr, numThreads); 814 diag = mlir::transform::gpu::mapOneForallToThreadsImpl( 815 rewriter, transformOp, forallOp, numThreads, syncAfterDistribute, 816 gpuLinearIdBuilder); 817 // Use silenceable failure to encode "failure to match" and pass through. 818 if (diag.isDefiniteFailure()) 819 return WalkResult::interrupt(); 820 if (diag.succeeded()) 821 return WalkResult::skip(); 822 823 //===--------------------------------------------------------------------===// 824 // Mapping to block ids (happens last so we can replay ThreadIdOp). 825 //===--------------------------------------------------------------------===// 826 LLVM_DEBUG( 827 llvm::interleaveComma( 828 blockDimsOfr, DBGS() << "mapNestedForallToThreadsImpl blockDims: "); 829 llvm::dbgs() << "\n"); 830 GpuThreadIdBuilder gpuThreadIdBuilder(ctx, blockDimsOfr, blockDims); 831 diag = mlir::transform::gpu::mapOneForallToThreadsImpl( 832 rewriter, transformOp, forallOp, blockDims, syncAfterDistribute, 833 gpuThreadIdBuilder); 834 // Use silenceable failure to encode "failure to match" and pass through. 835 if (diag.isDefiniteFailure()) 836 return WalkResult::interrupt(); 837 838 return WalkResult::advance(); 839 }); 840 if (walkResult.wasInterrupted()) 841 return diag; 842 843 // Replace ids of dimensions known to be 1 by 0 to simplify the IR. 844 // Here, the result of mapping determines the available mapping sizes. 845 replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero, 846 blockDims); 847 848 return DiagnosedSilenceableFailure::success(); 849 } 850 851 void transform::MapNestedForallToThreads::getEffects( 852 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 853 onlyReadsHandle(getTarget(), effects); 854 modifiesPayload(effects); 855 } 856 857 DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( 858 Operation *target, ApplyToEachResultList &results, TransformState &state) { 859 LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target); 860 auto transformOp = cast<TransformOpInterface>(getOperation()); 861 862 // Basic high-level verifications. 863 if (!gpuLaunch) 864 return emitSilenceableError() << "Given target is not a gpu.launch"; 865 866 // Mapping to block ids. 867 SmallVector<int64_t> blockDims{getBlockDims()}; 868 869 DiagnosedSilenceableFailure diag = 870 checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt, 871 blockDims[0], blockDims[1], blockDims[2]); 872 if (diag.isSilenceableFailure()) { 873 diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large"; 874 return diag; 875 } 876 877 // Set the GPU launch configuration for the block dims early, this is not 878 // subject to IR inspection. 879 IRRewriter rewriter(getContext()); 880 diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, 881 std::nullopt, std::nullopt, blockDims[0], blockDims[1], 882 blockDims[2]); 883 884 rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); 885 diag = 886 mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims, 887 getWarpDims(), getSyncAfterDistribute()); 888 889 return diag; 890 } 891 892 //===----------------------------------------------------------------------===// 893 // Transform op registration 894 //===----------------------------------------------------------------------===// 895 896 namespace { 897 /// Registers new ops and declares PDL as dependent dialect since the 898 /// additional ops are using PDL types for operands and results. 899 class GPUTransformDialectExtension 900 : public transform::TransformDialectExtension< 901 GPUTransformDialectExtension> { 902 public: 903 GPUTransformDialectExtension() { 904 declareDependentDialect<pdl::PDLDialect>(); 905 declareGeneratedDialect<scf::SCFDialect>(); 906 declareGeneratedDialect<arith::ArithDialect>(); 907 declareGeneratedDialect<GPUDialect>(); 908 registerTransformOps< 909 #define GET_OP_LIST 910 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" 911 >(); 912 } 913 }; 914 } // namespace 915 916 #define GET_OP_CLASSES 917 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc" 918 919 void mlir::gpu::registerTransformDialectExtension(DialectRegistry ®istry) { 920 registry.addExtensions<GPUTransformDialectExtension>(); 921 } 922