//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and it registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

// Given the element type of an operand and whether or not it is an
// accumulator, this function returns the PTX type (`NVVM::MMATypes`) that
// corresponds to the operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccum=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype =
        inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

// Maps the given WMMA fragment and (m, n, k) shape to the fragment's
// (rows, cols) extent and forwards to inferMMAType.
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

// Returns the K dimension required by wgmma.mma_async for the given
// multiplicand type, or failure if the type is unsupported.
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

// Checks whether the accumulator/result type `typeD` is compatible with the
// multiplicand types `typeA` and `typeB`.
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

// Checks whether `sizeN` is a supported N dimension for the given
// multiplicand type.
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16).
  // Matrices A should be stored in row-major and B in column-major.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 to 15");
  return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim is present, it must be an integer array
  // with at most 3 elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be integer array with maximum 3 index";
  }
  // If minctasm / maxnreg / cluster_max_blocks is present, it must be an
  // integer attribute.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"