//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print a more refined type here. Note that this should correspond to
// the parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`'
                         << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three segments assign the cluster size. In the region argument
  // list, these are the last 6 arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number of
  // WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that immediately follows another gpu.barrier; the
/// threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
1529 auto *body = result.addRegion(); 1530 return parser.parseRegion(*body, entryArgs); 1531 } 1532 1533 static void printAttributions(OpAsmPrinter &p, StringRef keyword, 1534 ArrayRef<BlockArgument> values, 1535 ArrayAttr attributes) { 1536 if (values.empty()) 1537 return; 1538 1539 p << ' ' << keyword << '('; 1540 llvm::interleaveComma( 1541 llvm::enumerate(values), p, [&p, attributes](auto pair) { 1542 BlockArgument v = pair.value(); 1543 p << v << " : " << v.getType(); 1544 1545 size_t attributionIndex = pair.index(); 1546 DictionaryAttr attrs; 1547 if (attributes && attributionIndex < attributes.size()) 1548 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]); 1549 if (attrs) 1550 p.printOptionalAttrDict(attrs.getValue()); 1551 }); 1552 p << ')'; 1553 } 1554 1555 void GPUFuncOp::print(OpAsmPrinter &p) { 1556 p << ' '; 1557 p.printSymbolName(getName()); 1558 1559 FunctionType type = getFunctionType(); 1560 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(), 1561 /*isVariadic=*/false, 1562 type.getResults()); 1563 1564 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(), 1565 getWorkgroupAttribAttrs().value_or(nullptr)); 1566 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(), 1567 getPrivateAttribAttrs().value_or(nullptr)); 1568 if (isKernel()) 1569 p << ' ' << getKernelKeyword(); 1570 1571 function_interface_impl::printFunctionAttributes( 1572 p, *this, 1573 {getNumWorkgroupAttributionsAttrName(), 1574 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), 1575 getArgAttrsAttrName(), getResAttrsAttrName(), 1576 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()}); 1577 p << ' '; 1578 p.printRegion(getBody(), /*printEntryBlockArgs=*/false); 1579 } 1580 1581 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, 1582 StringAttr attrName) { 1583 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName)); 1584 if (!allAttrs || index >= allAttrs.size()) 1585 return DictionaryAttr(); 1586 return llvm::cast<DictionaryAttr>(allAttrs[index]); 1587 } 1588 1589 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) { 1590 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName()); 1591 } 1592 1593 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) { 1594 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName()); 1595 } 1596 1597 static void setAttributionAttrs(GPUFuncOp op, unsigned index, 1598 DictionaryAttr value, StringAttr attrName) { 1599 MLIRContext *ctx = op.getContext(); 1600 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName)); 1601 SmallVector<Attribute> elements; 1602 if (allAttrs) 1603 elements.append(allAttrs.begin(), allAttrs.end()); 1604 while (elements.size() <= index) 1605 elements.push_back(DictionaryAttr::get(ctx)); 1606 if (!value) 1607 elements[index] = DictionaryAttr::get(ctx); 1608 else 1609 elements[index] = value; 1610 ArrayAttr newValue = ArrayAttr::get(ctx, elements); 1611 op->setAttr(attrName, newValue); 1612 } 1613 1614 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index, 1615 DictionaryAttr value) { 1616 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName()); 1617 } 1618 1619 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index, 1620 DictionaryAttr value) { 1621 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName()); 1622 } 1623 1624 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, 
1625 StringAttr name, StringAttr attrsName) { 1626 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName); 1627 if (!dict) 1628 return Attribute(); 1629 return dict.get(name); 1630 } 1631 1632 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index, 1633 StringAttr name) { 1634 assert(index < getNumWorkgroupAttributions() && 1635 "index must map to a workgroup attribution"); 1636 return getAttributionAttr(*this, index, name, 1637 getWorkgroupAttribAttrsAttrName()); 1638 } 1639 1640 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index, 1641 StringAttr name) { 1642 assert(index < getNumPrivateAttributions() && 1643 "index must map to a private attribution"); 1644 return getAttributionAttr(*this, index, name, 1645 getPrivateAttribAttrsAttrName()); 1646 } 1647 1648 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, 1649 Attribute value, StringAttr attrsName) { 1650 MLIRContext *ctx = op.getContext(); 1651 SmallVector<NamedAttribute> elems; 1652 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName); 1653 if (oldDict) 1654 elems.append(oldDict.getValue().begin(), oldDict.getValue().end()); 1655 1656 bool found = false; 1657 bool mustSort = true; 1658 for (unsigned i = 0, e = elems.size(); i < e; ++i) { 1659 if (elems[i].getName() == name) { 1660 found = true; 1661 if (!value) { 1662 std::swap(elems[i], elems[elems.size() - 1]); 1663 elems.pop_back(); 1664 } else { 1665 mustSort = false; 1666 elems[i] = NamedAttribute(elems[i].getName(), value); 1667 } 1668 break; 1669 } 1670 } 1671 if (!found) { 1672 if (!value) 1673 return; 1674 elems.emplace_back(name, value); 1675 } 1676 if (mustSort) { 1677 DictionaryAttr::sortInPlace(elems); 1678 } 1679 auto newDict = DictionaryAttr::getWithSorted(ctx, elems); 1680 setAttributionAttrs(op, index, newDict, attrsName); 1681 } 1682 1683 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name, 1684 Attribute value) { 1685 assert(index < getNumWorkgroupAttributions() && 1686 "index must map to a workgroup attribution"); 1687 setAttributionAttr(*this, index, name, value, 1688 getWorkgroupAttribAttrsAttrName()); 1689 } 1690 1691 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name, 1692 Attribute value) { 1693 assert(index < getNumPrivateAttributions() && 1694 "index must map to a private attribution"); 1695 setAttributionAttr(*this, index, name, value, 1696 getPrivateAttribAttrsAttrName()); 1697 } 1698 1699 LogicalResult GPUFuncOp::verifyType() { 1700 if (isKernel() && getFunctionType().getNumResults() != 0) 1701 return emitOpError() << "expected void return type for kernel function"; 1702 1703 return success(); 1704 } 1705 1706 /// Verifies the body of the function. 
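/// In particular, this checks that the entry block has at least as many
/// arguments as the function signature plus the workgroup attributions, that
/// the leading block argument types match the signature, and that the
/// workgroup and private attributions live in the matching GPU address spaces.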
1707 LogicalResult GPUFuncOp::verifyBody() { 1708 if (empty()) 1709 return emitOpError() << "expected body with at least one block"; 1710 unsigned numFuncArguments = getNumArguments(); 1711 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); 1712 unsigned numBlockArguments = front().getNumArguments(); 1713 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) 1714 return emitOpError() << "expected at least " 1715 << numFuncArguments + numWorkgroupAttributions 1716 << " arguments to body region"; 1717 1718 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs(); 1719 for (unsigned i = 0; i < numFuncArguments; ++i) { 1720 Type blockArgType = front().getArgument(i).getType(); 1721 if (funcArgTypes[i] != blockArgType) 1722 return emitOpError() << "expected body region argument #" << i 1723 << " to be of type " << funcArgTypes[i] << ", got " 1724 << blockArgType; 1725 } 1726 1727 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(), 1728 GPUDialect::getWorkgroupAddressSpace())) || 1729 failed(verifyAttributions(getOperation(), getPrivateAttributions(), 1730 GPUDialect::getPrivateAddressSpace()))) 1731 return failure(); 1732 1733 return success(); 1734 } 1735 1736 //===----------------------------------------------------------------------===// 1737 // ReturnOp 1738 //===----------------------------------------------------------------------===// 1739 1740 LogicalResult gpu::ReturnOp::verify() { 1741 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>(); 1742 1743 FunctionType funType = function.getFunctionType(); 1744 1745 if (funType.getNumResults() != getOperands().size()) 1746 return emitOpError() 1747 .append("expected ", funType.getNumResults(), " result operands") 1748 .attachNote(function.getLoc()) 1749 .append("return type declared here"); 1750 1751 for (const auto &pair : llvm::enumerate( 1752 llvm::zip(function.getFunctionType().getResults(), getOperands()))) { 1753 auto [type, operand] = pair.value(); 1754 if (type != operand.getType()) 1755 return emitOpError() << "unexpected type `" << operand.getType() 1756 << "' for operand #" << pair.index(); 1757 } 1758 return success(); 1759 } 1760 1761 //===----------------------------------------------------------------------===// 1762 // GPUModuleOp 1763 //===----------------------------------------------------------------------===// 1764 1765 void GPUModuleOp::build(OpBuilder &builder, OperationState &result, 1766 StringRef name, ArrayAttr targets, 1767 Attribute offloadingHandler) { 1768 result.addRegion()->emplaceBlock(); 1769 Properties &props = result.getOrAddProperties<Properties>(); 1770 if (targets) 1771 props.targets = targets; 1772 props.setSymName(builder.getStringAttr(name)); 1773 props.offloadingHandler = offloadingHandler; 1774 } 1775 1776 void GPUModuleOp::build(OpBuilder &builder, OperationState &result, 1777 StringRef name, ArrayRef<Attribute> targets, 1778 Attribute offloadingHandler) { 1779 build(builder, result, name, 1780 targets.empty() ? 
ArrayAttr() : builder.getArrayAttr(targets), 1781 offloadingHandler); 1782 } 1783 1784 bool GPUModuleOp::hasTarget(Attribute target) { 1785 if (ArrayAttr targets = getTargetsAttr()) 1786 return llvm::count(targets.getValue(), target); 1787 return false; 1788 } 1789 1790 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) { 1791 ArrayAttr &targetsAttr = getProperties().targets; 1792 SmallVector<Attribute> targetsVector(targets); 1793 targetsAttr = ArrayAttr::get(getContext(), targetsVector); 1794 } 1795 1796 //===----------------------------------------------------------------------===// 1797 // GPUBinaryOp 1798 //===----------------------------------------------------------------------===// 1799 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name, 1800 Attribute offloadingHandler, ArrayAttr objects) { 1801 auto &properties = result.getOrAddProperties<Properties>(); 1802 result.attributes.push_back(builder.getNamedAttr( 1803 SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); 1804 properties.objects = objects; 1805 if (offloadingHandler) 1806 properties.offloadingHandler = offloadingHandler; 1807 else 1808 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr); 1809 } 1810 1811 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name, 1812 Attribute offloadingHandler, ArrayRef<Attribute> objects) { 1813 build(builder, result, name, offloadingHandler, 1814 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects)); 1815 } 1816 1817 static ParseResult parseOffloadingHandler(OpAsmParser &parser, 1818 Attribute &offloadingHandler) { 1819 if (succeeded(parser.parseOptionalLess())) { 1820 if (parser.parseAttribute(offloadingHandler)) 1821 return failure(); 1822 if (parser.parseGreater()) 1823 return failure(); 1824 } 1825 if (!offloadingHandler) 1826 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr); 1827 return success(); 1828 } 1829 1830 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, 1831 Attribute offloadingHandler) { 1832 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr)) 1833 printer << '<' << offloadingHandler << '>'; 1834 } 1835 1836 //===----------------------------------------------------------------------===// 1837 // GPUMemcpyOp 1838 //===----------------------------------------------------------------------===// 1839 1840 LogicalResult MemcpyOp::verify() { 1841 auto srcType = getSrc().getType(); 1842 auto dstType = getDst().getType(); 1843 1844 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType)) 1845 return emitOpError("arguments have incompatible element type"); 1846 1847 if (failed(verifyCompatibleShape(srcType, dstType))) 1848 return emitOpError("arguments have incompatible shape"); 1849 1850 return success(); 1851 } 1852 1853 namespace { 1854 1855 /// Erases a common case of copy ops where a destination value is used only by 1856 /// the copy op, alloc and dealloc ops. 1857 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> { 1858 using OpRewritePattern<MemcpyOp>::OpRewritePattern; 1859 1860 LogicalResult matchAndRewrite(MemcpyOp op, 1861 PatternRewriter &rewriter) const override { 1862 Value dest = op.getDst(); 1863 Operation *destDefOp = dest.getDefiningOp(); 1864 // `dest` must be defined by an op having Allocate memory effect in order to 1865 // perform the folding. 
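    // A typical (illustrative) match: %dst defined by a gpu.alloc whose only
    // other users are dealloc-like ops, so the copied-to buffer can never be
    // observed and the copy can be folded away.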
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected the most minor dimension of the source memref to have unit "
        "stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected the most minor dimension of the destination memref to have "
        "unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");
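  // The operand shapes must describe a matmul C += A * B: A's second dimension
  // must match B's first, and C's dimensions must match A's rows and B's
  // columns.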
1951 1952 ArrayRef<int64_t> aShape, bShape, cShape; 1953 aShape = opTypes[A].getShape(); 1954 bShape = opTypes[B].getShape(); 1955 cShape = opTypes[C].getShape(); 1956 1957 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || 1958 bShape[1] != cShape[1]) 1959 return emitError("operand shapes do not satisfy matmul constraints"); 1960 1961 return success(); 1962 } 1963 1964 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor, 1965 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1966 return memref::foldMemRefCast(*this); 1967 } 1968 1969 LogicalResult MemsetOp::fold(FoldAdaptor adaptor, 1970 SmallVectorImpl<::mlir::OpFoldResult> &results) { 1971 return memref::foldMemRefCast(*this); 1972 } 1973 1974 //===----------------------------------------------------------------------===// 1975 // GPU_WaitOp 1976 //===----------------------------------------------------------------------===// 1977 1978 namespace { 1979 1980 /// Remove gpu.wait op use of gpu.wait op def without async dependencies. 1981 /// %t = gpu.wait async [] // No async dependencies. 1982 /// ... gpu.wait ... [%t, ...] // %t can be removed. 1983 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> { 1984 public: 1985 using OpRewritePattern::OpRewritePattern; 1986 1987 LogicalResult matchAndRewrite(WaitOp op, 1988 PatternRewriter &rewriter) const final { 1989 auto predicate = [](Value value) { 1990 auto waitOp = value.getDefiningOp<WaitOp>(); 1991 return waitOp && waitOp->getNumOperands() == 0; 1992 }; 1993 if (llvm::none_of(op.getAsyncDependencies(), predicate)) 1994 return failure(); 1995 SmallVector<Value> validOperands; 1996 for (Value operand : op->getOperands()) { 1997 if (predicate(operand)) 1998 continue; 1999 validOperands.push_back(operand); 2000 } 2001 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); }); 2002 return success(); 2003 } 2004 }; 2005 2006 /// Simplify trivial gpu.wait ops for the following patterns. 2007 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async 2008 /// dependencies). 2009 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with 2010 /// %t0. 2011 /// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async 2012 /// dependencies nor return any token. 2013 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> { 2014 public: 2015 using OpRewritePattern::OpRewritePattern; 2016 2017 LogicalResult matchAndRewrite(WaitOp op, 2018 PatternRewriter &rewriter) const final { 2019 // Erase gpu.wait ops that neither have any async dependencies nor return 2020 // any async token. 2021 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) { 2022 rewriter.eraseOp(op); 2023 return success(); 2024 } 2025 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op. 2026 if (llvm::hasSingleElement(op.getAsyncDependencies()) && 2027 op.getAsyncToken()) { 2028 rewriter.replaceOp(op, op.getAsyncDependencies()); 2029 return success(); 2030 } 2031 // Erase %t = gpu.wait async ... ops, where %t has no uses. 
2032 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) { 2033 rewriter.eraseOp(op); 2034 return success(); 2035 } 2036 return failure(); 2037 } 2038 }; 2039 2040 } // end anonymous namespace 2041 2042 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results, 2043 MLIRContext *context) { 2044 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context); 2045 } 2046 2047 //===----------------------------------------------------------------------===// 2048 // GPU_AllocOp 2049 //===----------------------------------------------------------------------===// 2050 2051 LogicalResult AllocOp::verify() { 2052 auto memRefType = llvm::cast<MemRefType>(getMemref().getType()); 2053 2054 if (getDynamicSizes().size() != memRefType.getNumDynamicDims()) 2055 return emitOpError("dimension operand count does not equal memref " 2056 "dynamic dimension count"); 2057 2058 unsigned numSymbols = 0; 2059 if (!memRefType.getLayout().isIdentity()) 2060 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); 2061 if (getSymbolOperands().size() != numSymbols) { 2062 return emitOpError( 2063 "symbol operand count does not equal memref symbol count"); 2064 } 2065 2066 return success(); 2067 } 2068 2069 namespace { 2070 2071 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to 2072 /// `memref::AllocOp`. 2073 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> { 2074 using OpRewritePattern<memref::DimOp>::OpRewritePattern; 2075 2076 LogicalResult matchAndRewrite(memref::DimOp dimOp, 2077 PatternRewriter &rewriter) const override { 2078 std::optional<int64_t> index = dimOp.getConstantIndex(); 2079 if (!index) 2080 return failure(); 2081 2082 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType()); 2083 if (!memrefType || index.value() >= memrefType.getRank() || 2084 !memrefType.isDynamicDim(index.value())) 2085 return failure(); 2086 2087 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>(); 2088 if (!alloc) 2089 return failure(); 2090 2091 Value substituteOp = *(alloc.getDynamicSizes().begin() + 2092 memrefType.getDynamicDimIndex(index.value())); 2093 rewriter.replaceOp(dimOp, substituteOp); 2094 return success(); 2095 } 2096 }; 2097 2098 } // namespace 2099 2100 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, 2101 MLIRContext *context) { 2102 results.add<SimplifyDimOfAllocOp>(context); 2103 } 2104 2105 //===----------------------------------------------------------------------===// 2106 // GPU object attribute 2107 //===----------------------------------------------------------------------===// 2108 2109 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError, 2110 Attribute target, CompilationTarget format, 2111 StringAttr object, DictionaryAttr properties, 2112 KernelTableAttr kernels) { 2113 if (!target) 2114 return emitError() << "the target attribute cannot be null"; 2115 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) 2116 return success(); 2117 return emitError() << "the target attribute must implement or promise the " 2118 "`gpu::TargetAttrInterface`"; 2119 } 2120 2121 namespace { 2122 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format, 2123 StringAttr &object) { 2124 std::optional<CompilationTarget> formatResult; 2125 StringRef enumKeyword; 2126 auto loc = odsParser.getCurrentLocation(); 2127 if (failed(odsParser.parseOptionalKeyword(&enumKeyword))) 2128 formatResult = CompilationTarget::Fatbin; 2129 if (!formatResult && 2130 
(formatResult = 2131 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) && 2132 odsParser.parseEqual()) 2133 return odsParser.emitError(loc, "expected an equal sign"); 2134 if (!formatResult) 2135 return odsParser.emitError(loc, "expected keyword for GPU object format"); 2136 FailureOr<StringAttr> objectResult = 2137 FieldParser<StringAttr>::parse(odsParser); 2138 if (failed(objectResult)) 2139 return odsParser.emitError(odsParser.getCurrentLocation(), 2140 "failed to parse GPU_ObjectAttr parameter " 2141 "'object' which is to be a `StringAttr`"); 2142 format = *formatResult; 2143 object = *objectResult; 2144 return success(); 2145 } 2146 2147 void printObject(AsmPrinter &odsParser, CompilationTarget format, 2148 StringAttr object) { 2149 if (format != CompilationTarget::Fatbin) 2150 odsParser << stringifyEnum(format) << " = "; 2151 odsParser << object; 2152 } 2153 } // namespace 2154 2155 //===----------------------------------------------------------------------===// 2156 // GPU select object attribute 2157 //===----------------------------------------------------------------------===// 2158 2159 LogicalResult 2160 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError, 2161 Attribute target) { 2162 // Check `target`, it can be null, an integer attr or a GPU Target attribute. 2163 if (target) { 2164 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) { 2165 if (intAttr.getInt() < 0) { 2166 return emitError() << "the object index must be positive"; 2167 } 2168 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) { 2169 return emitError() 2170 << "the target attribute must be a GPU Target attribute"; 2171 } 2172 } 2173 return success(); 2174 } 2175 2176 //===----------------------------------------------------------------------===// 2177 // DynamicSharedMemoryOp 2178 //===----------------------------------------------------------------------===// 2179 2180 LogicalResult gpu::DynamicSharedMemoryOp::verify() { 2181 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>()) 2182 return emitOpError() << "must be inside an op with symbol table"; 2183 2184 MemRefType memrefType = getResultMemref().getType(); 2185 // Check address space 2186 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) { 2187 return emitOpError() << "address space must be " 2188 << gpu::AddressSpaceAttr::getMnemonic() << "<" 2189 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">"; 2190 } 2191 if (memrefType.hasStaticShape()) { 2192 return emitOpError() << "result memref type must be memref<?xi8, " 2193 "#gpu.address_space<workgroup>>"; 2194 } 2195 return success(); 2196 } 2197 2198 //===----------------------------------------------------------------------===// 2199 // GPU WarpExecuteOnLane0Op 2200 //===----------------------------------------------------------------------===// 2201 2202 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) { 2203 p << "(" << getLaneid() << ")"; 2204 2205 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()}; 2206 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName()); 2207 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]"; 2208 2209 if (!getArgs().empty()) 2210 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")"; 2211 if (!getResults().empty()) 2212 p << " -> (" << getResults().getTypes() << ')'; 2213 p << " "; 2214 p.printRegion(getRegion(), 2215 /*printEntryBlockArgs=*/true, 2216 /*printBlockTerminators=*/!getResults().empty()); 2217 
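  // The warp size attribute is elided from the trailing attribute dictionary
  // because it was already printed in square brackets above.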
p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr); 2218 } 2219 2220 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, 2221 OperationState &result) { 2222 // Create the region. 2223 result.regions.reserve(1); 2224 Region *warpRegion = result.addRegion(); 2225 2226 auto &builder = parser.getBuilder(); 2227 OpAsmParser::UnresolvedOperand laneId; 2228 2229 // Parse predicate operand. 2230 if (parser.parseLParen() || 2231 parser.parseOperand(laneId, /*allowResultNumber=*/false) || 2232 parser.parseRParen()) 2233 return failure(); 2234 2235 int64_t warpSize; 2236 if (parser.parseLSquare() || parser.parseInteger(warpSize) || 2237 parser.parseRSquare()) 2238 return failure(); 2239 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(), 2240 builder.getContext())), 2241 builder.getI64IntegerAttr(warpSize)); 2242 2243 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands)) 2244 return failure(); 2245 2246 llvm::SMLoc inputsOperandsLoc; 2247 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands; 2248 SmallVector<Type> inputTypes; 2249 if (succeeded(parser.parseOptionalKeyword("args"))) { 2250 if (parser.parseLParen()) 2251 return failure(); 2252 2253 inputsOperandsLoc = parser.getCurrentLocation(); 2254 if (parser.parseOperandList(inputsOperands) || 2255 parser.parseColonTypeList(inputTypes) || parser.parseRParen()) 2256 return failure(); 2257 } 2258 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, 2259 result.operands)) 2260 return failure(); 2261 2262 // Parse optional results type list. 2263 if (parser.parseOptionalArrowTypeList(result.types)) 2264 return failure(); 2265 // Parse the region. 2266 if (parser.parseRegion(*warpRegion, /*arguments=*/{}, 2267 /*argTypes=*/{})) 2268 return failure(); 2269 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location); 2270 2271 // Parse the optional attribute list. 2272 if (parser.parseOptionalAttrDict(result.attributes)) 2273 return failure(); 2274 return success(); 2275 } 2276 2277 void WarpExecuteOnLane0Op::getSuccessorRegions( 2278 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 2279 if (!point.isParent()) { 2280 regions.push_back(RegionSuccessor(getResults())); 2281 return; 2282 } 2283 2284 // The warp region is always executed 2285 regions.push_back(RegionSuccessor(&getWarpRegion())); 2286 } 2287 2288 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result, 2289 TypeRange resultTypes, Value laneId, 2290 int64_t warpSize) { 2291 build(builder, result, resultTypes, laneId, warpSize, 2292 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt); 2293 } 2294 2295 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result, 2296 TypeRange resultTypes, Value laneId, 2297 int64_t warpSize, ValueRange args, 2298 TypeRange blockArgTypes) { 2299 result.addOperands(laneId); 2300 result.addAttribute(getAttributeNames()[0], 2301 builder.getI64IntegerAttr(warpSize)); 2302 result.addTypes(resultTypes); 2303 result.addOperands(args); 2304 assert(args.size() == blockArgTypes.size()); 2305 OpBuilder::InsertionGuard guard(builder); 2306 Region *warpRegion = result.addRegion(); 2307 Block *block = builder.createBlock(warpRegion); 2308 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args)) 2309 block->addArgument(type, arg.getLoc()); 2310 } 2311 2312 /// Helper check if the distributed vector type is consistent with the expanded 2313 /// type and distributed size. 
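/// For example (illustrative values), with a warp size of 32 an expanded type
/// of vector<2x64xf32> can be distributed as vector<2x2xf32>: each distributed
/// dimension must evenly divide the expanded one, and the product of the
/// per-dimension factors (here 1 * 32) must equal the warp size.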
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match, there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}

LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number of op arguments and block arguments.");
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//

KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
kernel.getAllArgAttrs(), metadata); 2397 } 2398 2399 KernelMetadataAttr 2400 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const { 2401 if (attrs.empty()) 2402 return *this; 2403 NamedAttrList attrList; 2404 if (DictionaryAttr dict = getMetadata()) 2405 attrList.append(dict); 2406 attrList.append(attrs); 2407 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(), 2408 attrList.getDictionary(getContext())); 2409 } 2410 2411 LogicalResult 2412 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError, 2413 StringAttr name, Type functionType, 2414 ArrayAttr argAttrs, DictionaryAttr metadata) { 2415 if (name.empty()) 2416 return emitError() << "the kernel name can't be empty"; 2417 if (argAttrs) { 2418 if (llvm::any_of(argAttrs, [](Attribute attr) { 2419 return !llvm::isa<DictionaryAttr>(attr); 2420 })) 2421 return emitError() 2422 << "all attributes in the array must be a dictionary attribute"; 2423 } 2424 return success(); 2425 } 2426 2427 //===----------------------------------------------------------------------===// 2428 // GPU KernelTableAttr 2429 //===----------------------------------------------------------------------===// 2430 2431 KernelTableAttr KernelTableAttr::get(MLIRContext *context, 2432 ArrayRef<KernelMetadataAttr> kernels, 2433 bool isSorted) { 2434 // Note that `is_sorted` is always only invoked once even with assertions ON. 2435 assert((!isSorted || llvm::is_sorted(kernels)) && 2436 "expected a sorted kernel array"); 2437 // Immediately return the attribute if the array is sorted. 2438 if (isSorted || llvm::is_sorted(kernels)) 2439 return Base::get(context, kernels); 2440 // Sort the array. 2441 SmallVector<KernelMetadataAttr> kernelsTmp(kernels); 2442 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end()); 2443 return Base::get(context, kernelsTmp); 2444 } 2445 2446 KernelTableAttr KernelTableAttr::getChecked( 2447 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, 2448 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) { 2449 // Note that `is_sorted` is always only invoked once even with assertions ON. 2450 assert((!isSorted || llvm::is_sorted(kernels)) && 2451 "expected a sorted kernel array"); 2452 // Immediately return the attribute if the array is sorted. 2453 if (isSorted || llvm::is_sorted(kernels)) 2454 return Base::getChecked(emitError, context, kernels); 2455 // Sort the array. 2456 SmallVector<KernelMetadataAttr> kernelsTmp(kernels); 2457 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end()); 2458 return Base::getChecked(emitError, context, kernelsTmp); 2459 } 2460 2461 LogicalResult 2462 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError, 2463 ArrayRef<KernelMetadataAttr> kernels) { 2464 if (kernels.size() < 2) 2465 return success(); 2466 // Check that the kernels are uniquely named. 2467 if (std::adjacent_find(kernels.begin(), kernels.end(), 2468 [](KernelMetadataAttr l, KernelMetadataAttr r) { 2469 return l.getName() == r.getName(); 2470 }) != kernels.end()) { 2471 return emitError() << "expected all kernels to be uniquely named"; 2472 } 2473 return success(); 2474 } 2475 2476 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const { 2477 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key); 2478 return found ? *iterator : KernelMetadataAttr(); 2479 } 2480 2481 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const { 2482 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key); 2483 return found ? 
*iterator : KernelMetadataAttr(); 2484 } 2485 2486 //===----------------------------------------------------------------------===// 2487 // GPU target options 2488 //===----------------------------------------------------------------------===// 2489 2490 TargetOptions::TargetOptions( 2491 StringRef toolkitPath, ArrayRef<Attribute> librariesToLink, 2492 StringRef cmdOptions, StringRef elfSection, 2493 CompilationTarget compilationTarget, 2494 function_ref<SymbolTable *()> getSymbolTableCallback, 2495 function_ref<void(llvm::Module &)> initialLlvmIRCallback, 2496 function_ref<void(llvm::Module &)> linkedLlvmIRCallback, 2497 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, 2498 function_ref<void(StringRef)> isaCallback) 2499 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink, 2500 cmdOptions, elfSection, compilationTarget, 2501 getSymbolTableCallback, initialLlvmIRCallback, 2502 linkedLlvmIRCallback, optimizedLlvmIRCallback, 2503 isaCallback) {} 2504 2505 TargetOptions::TargetOptions( 2506 TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink, 2507 StringRef cmdOptions, StringRef elfSection, 2508 CompilationTarget compilationTarget, 2509 function_ref<SymbolTable *()> getSymbolTableCallback, 2510 function_ref<void(llvm::Module &)> initialLlvmIRCallback, 2511 function_ref<void(llvm::Module &)> linkedLlvmIRCallback, 2512 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback, 2513 function_ref<void(StringRef)> isaCallback) 2514 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), 2515 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), 2516 compilationTarget(compilationTarget), 2517 getSymbolTableCallback(getSymbolTableCallback), 2518 initialLlvmIRCallback(initialLlvmIRCallback), 2519 linkedLlvmIRCallback(linkedLlvmIRCallback), 2520 optimizedLlvmIRCallback(optimizedLlvmIRCallback), 2521 isaCallback(isaCallback), typeID(typeID) {} 2522 2523 TypeID TargetOptions::getTypeID() const { return typeID; } 2524 2525 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; } 2526 2527 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const { 2528 return librariesToLink; 2529 } 2530 2531 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; } 2532 2533 StringRef TargetOptions::getELFSection() const { return elfSection; } 2534 2535 SymbolTable *TargetOptions::getSymbolTable() const { 2536 return getSymbolTableCallback ? 
getSymbolTableCallback() : nullptr; 2537 } 2538 2539 function_ref<void(llvm::Module &)> 2540 TargetOptions::getInitialLlvmIRCallback() const { 2541 return initialLlvmIRCallback; 2542 } 2543 2544 function_ref<void(llvm::Module &)> 2545 TargetOptions::getLinkedLlvmIRCallback() const { 2546 return linkedLlvmIRCallback; 2547 } 2548 2549 function_ref<void(llvm::Module &)> 2550 TargetOptions::getOptimizedLlvmIRCallback() const { 2551 return optimizedLlvmIRCallback; 2552 } 2553 2554 function_ref<void(StringRef)> TargetOptions::getISACallback() const { 2555 return isaCallback; 2556 } 2557 2558 CompilationTarget TargetOptions::getCompilationTarget() const { 2559 return compilationTarget; 2560 } 2561 2562 CompilationTarget TargetOptions::getDefaultCompilationTarget() { 2563 return CompilationTarget::Fatbin; 2564 } 2565 2566 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> 2567 TargetOptions::tokenizeCmdOptions() const { 2568 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options; 2569 llvm::StringSaver stringSaver(options.first); 2570 StringRef opts = cmdOptions; 2571 // For a correct tokenization of the command line options `opts` must be 2572 // unquoted, otherwise the tokenization function returns a single string: the 2573 // unquoted `cmdOptions` -which is not the desired behavior. 2574 // Remove any quotes if they are at the beginning and end of the string: 2575 if (!opts.empty() && opts.front() == '"' && opts.back() == '"') 2576 opts.consume_front("\""), opts.consume_back("\""); 2577 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'') 2578 opts.consume_front("'"), opts.consume_back("'"); 2579 #ifdef _WIN32 2580 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second, 2581 /*MarkEOLs=*/false); 2582 #else 2583 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second, 2584 /*MarkEOLs=*/false); 2585 #endif // _WIN32 2586 return options; 2587 } 2588 2589 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions) 2590 2591 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc" 2592 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc" 2593 2594 #define GET_ATTRDEF_CLASSES 2595 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc" 2596 2597 #define GET_OP_CLASSES 2598 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc" 2599 2600 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc" 2601