//===- GPUTransformOps.cpp - Implementation of GPU transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;

#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
  // Used in populateGpuToNVVMConversionPatterns, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
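  // Illustrative effect of the conversion registered below (derived from the
  // switch cases, not from upstream documentation): a memref in
  // #gpu.address_space<workgroup> lowers to an LLVM pointer in addrspace(3),
  // #gpu.address_space<private> to addrspace(0), and
  // #gpu.address_space<global> to NVVM's global memory space, addrspace(1).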
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Used in GPUToNVVM/WmmaOpsToNvvm.cpp, so attaching here for now.
  // TODO: We should have a single to_nvvm_type_converter.
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuRewritePatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ApplyUnrollVectorsSubgroupMmaOp
//===----------------------------------------------------------------------===//

/// Pick an unrolling order that allows tensor core operations to reuse the
/// LHS register.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // First make reduction the outer dimensions.
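  // Worked example (illustrative, assuming a canonical matmul contraction
  // with iterator types [parallel, parallel, reduction] and an LHS indexing
  // map of (d0, d2)): the loops below produce the order [2, 0, 1], i.e. the
  // reduction dimension first, then the parallel dimension shared with the
  // LHS, then the remaining parallel dimension.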
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isReductionIterator(iter)) {
      order.push_back(index);
    }
  }

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  }
  // Then the parallel dimensions that are part of the LHS, as we want to
  // reuse the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && dims.count(index)) {
      order.push_back(index);
    }
  }
  // Then the remaining parallel loops.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
    if (vector::isParallelIterator(iter) && !dims.count(index)) {
      order.push_back(index);
    }
  }
  return order;
}

/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, and `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer read ops may need different shapes based on how they are being
    // used. For simplicity, just match the shape used by the
    // vector.extract_strided_slice users.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      // TODO: The condition for unrolling elementwise should be restricted
      // only to operations that need unrolling (connected to the contract).
      if (vecType.getRank() < 2)
        return std::nullopt;

      // First check whether there is a slice to infer the shape from. This is
      // required for cases where the accumulator type differs from the input
      // types, in which case we will see an `arith.ext_` between the contract
      // and transfer_read which needs to be unrolled.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = cast<VectorType>(extract.getResult().getType());
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());

      // Else unroll for trailing elementwise.
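      // (Illustrative: an elementwise op producing vector<2x16x16xf32> with
      // m = n = 16 gets the native shape [1, 16, 16].)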
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      // Map elementwise ops to the output shape.
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}

void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };

  int64_t m = getM();
  int64_t n = getN();
  int64_t k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}

//===----------------------------------------------------------------------===//
// EliminateBarriersOp
//===----------------------------------------------------------------------===//

void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) {
  populateGpuEliminateBarriersPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Block and thread mapping utilities.
//===----------------------------------------------------------------------===//

namespace {
/// Local types used for mapping verification.
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
} // namespace

static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}

/// Check that the given mapping attributes are of the desired kind.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread', 'warp', or 'warpgroup'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}

template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check that the types of the mapping attributes match.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  // TODO: This would be more natural with support for Optional<EnumParameter>
  // in GPUDeviceMappingAttr.
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}

/// Struct to return the result of the rewrite of a forall operation.
struct ForallRewriteResult {
  SmallVector<int64_t> mappingSizes;
  SmallVector<Value> mappingIds;
};

/// Helper to replace ids of dimensions known to be 1 by 0 to simplify the IR.
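/// For example (illustrative): instantiated with `OpTy = ThreadIdOp` and
/// `availableMappingSizes = {128, 1, 1}`, every `gpu.thread_id y` and
/// `gpu.thread_id z` under `parent` has its uses replaced by `replacement`
/// (typically a constant 0).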
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}

static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Step 1.b. In the linear case, compute the max mapping to avoid needlessly
  // mapping all dimensions. In the 3-D mapping case we need to map all
  // dimensions.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr exceeds the largest linear mapping in use, just skip it.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Try to insert. If the element was already present, just continue.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise, we have a new insertion without a size -> use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Step 3. Generate the mappingIdOps using the provided generator.
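  // (Illustrative: a non-linear block mapping materializes one gpu.block_id
  // per dimension of `originalBasis`, whereas a linear mapping linearizes the
  // hardware ids into a single index and then delinearizes it according to
  // `forallMappingSizes`; see GpuIdBuilder in GPU/TransformOps/Utils.h.)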
  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forallOp);
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }

  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Step 4. Map the induction variables to the mappingIdOps; this may involve
  // a permutation.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // Step 5. If the originalBasis is already known, create conditionals to
  // predicate the region. Otherwise, the current forall determines the
  // originalBasis and no predication occurs.
  Value predicate;
  if (originalBasisWasProvided) {
    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
    SmallVector<int64_t> availableMappingSizes =
        builderResult.availableMappingSizes;
    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
    // clang-format off
    LLVM_DEBUG(
        llvm::interleaveComma(
            activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
            availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    // clang-format on
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling before mapping or map to more threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Step 6. Move the body of forallOp.
  // Erase the terminator first; it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // Step 6.a. If predicated, move the body to the start of the `then` block.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Step 6.b. Otherwise, move inline just at the rewriter insertion
    // point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // Step 7. RAUW indices.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Step 8. Erase the old op.
  rewriter.eraseOp(forallOp);

  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");

  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MapForallToBlocks
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOps.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements and immediately reset
    // the insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp,
      /*availableMappingSizes=*/gridDims, rewriteResult, gpuIdBuilder);

  // Return if anything goes wrong; use silenceable failure as a match
  // failure.
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, set it from the result.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of the mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          rewriteResult.mappingSizes);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
                                           scf::ForallOp &topLevelForallOp,
                                           TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      // TODO: Handle multiple foralls if they are independent.
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique top-level scf.forall";
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not a gpu.launch; set the `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = mlir::transform::gpu::findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform requires size-3 mapping");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(topLevelForallOp);

  // Generate the gpu.launch here and move the scf.forall inside it.
  if (getGenerateGpuLaunch()) {
    DiagnosedSilenceableFailure diag =
        createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The BlockIdBuilder adapts to whatever is thrown at it.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late; this is
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//

static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
    return diag;
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    auto diag = definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  // Start mapping.
  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {

  {
    // GPU-specific verifications. There is no better place to anchor
    // those right now: the ForallOp is target-independent and the transform
    // op does not apply to individual ForallOps.
    DiagnosedSilenceableFailure diag =
        verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
    if (!diag.succeeded())
      return diag;
  }

  GpuIdBuilder gpuIdBuilder;
  {
    // Try to construct the id builder; if it fails, return.
    DiagnosedSilenceableFailure diag = getThreadIdBuilder(
        transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
    if (!diag.succeeded())
      return diag;
  }

  Location loc = forallOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  // Insert after the forall to allow for a syncthreads after `forall` is
  // erased.
  rewriter.setInsertionPointAfter(forallOp);
  ForallRewriteResult rewriteResult;
  DiagnosedSilenceableFailure diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // Add a syncthreads if needed. TODO: warpsync
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
        rewriter, transformOp, forallOp, blockDims, warpSize,
        syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
  // Here, the result of the mapping determines the available mapping sizes.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);

  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Basic high-level verifications.
  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  // Mapping to block ids.
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early; this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers the additional GPU transform ops with the transform dialect and
/// declares the dialects whose ops they may generate.
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"

void mlir::gpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<GPUTransformDialectExtension>();
}
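
// Illustrative end-to-end usage of the mapping ops registered above (a sketch
// only; the authoritative assembly syntax is defined in GPUTransformOps.td):
//
//   %launch = transform.gpu.map_forall_to_blocks %func
//       generate_gpu_launch grid_dims = [16, 8, 1]
//       : (!transform.any_op) -> !transform.any_op
//   transform.gpu.map_nested_forall_to_threads %launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op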