//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR. It also registers the dialect.
//
// The NVVM dialect only contains GPU-specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared across:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) ops.
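// For example (illustrative): a 5-D copy in im2col mode must carry exactly
// 3 im2col offsets, since tensorDims == numIm2ColOffsets + 2 is required.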
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects between 1 and 5 coordinate dimensions");

  if (numIm2ColOffsets) {
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (tensorDims != (numIm2ColOffsets + 2))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("at most 5 coordinates/dimensions are supported");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("only CG and CA cache modifiers are supported");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}

// Given the element type of an operand and whether or not it is an
// accumulator, this function returns the PTX type (`NVVM::MMATypes`) that
// corresponds to the operand's element type.
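// Illustrative mapping (not exhaustive): f64 -> f64; f16 and vector<2xf16>
// -> f16; f32 -> f32 as an accumulator but tf32 as a multiplicand; integer
// types -> s32, and only as accumulators. For LLVM struct types, the first
// member's type decides.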
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
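  // For an f16 MMA, this prints a suffix such as (illustrative):
  //   : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
  //       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>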
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//     `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
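  // For example (illustrative): A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3].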
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
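  // Each entry in expectedA/expectedB/expectedC is one complete, allowed
  // register-type list for that operand segment, e.g. (illustrative) four
  // i32 registers for the A fragment of a 16x8x32 s8 MMA.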
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
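  // E.g. (illustrative) shape <16, 8, 64> with f16 multiplicands is rejected
  // here, since f16 only allows k = 8 or k = 16.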
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operand) {
        return operand.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}
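// Returns the single k extent allowed by wgmma.mma_async for the given
// multiplicand type, e.g. (illustrative) k = 16 for f16/bf16 and k = 256
// for b1; fails for accumulator-only types such as f32/s32.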
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be the same type but there is "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored row-major and B column-major. Only f16/bf16
  // matrices can be stored in either column-major or row-major, by setting
  // the transpose values (imm-trans-a, imm-trans-b) in the PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result registers, however output struct has "
                         << outputSize << " elements";
  }

  // Check satfinite (only available for the s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Map the read/write registers correctly: the output registers are both
  // read and written, so the input operands start at twice the output count.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid or reqntid is present, it must be an integer array with at
  // most 3 elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm or maxnreg is present, it must be an integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"