//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and it also registers the dialect.
//
// The NVVM dialect only contains GPU-specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc,
                     "expects coordinates to be between 1 and 5 dimensions");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. The tensor must always be at least 3-dimensional.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When Im2ColOffsets are present, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "the number of im2col offsets must be two less than the number "
               "of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("at most 5 coordinates (dimensions) are supported.");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

LogicalResult CvtFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    if (getSat() != NVVM::SaturationMode::NONE)
      return emitError(
          "Saturation mode not supported with rn/rz rounding modes.");
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
  }
  return success();
}

// Given the element type of an operand and whether or not it is an
// accumulator, this function returns the PTX type (`NVVM::MMATypes`) that
// corresponds to the operand's element type.
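// For example, both f16 and vector<2xf16> map to MMATypes::f16, f32 maps to
// MMATypes::f32 when used as an accumulator and to MMATypes::tf32 otherwise,
// and integer element types are only inferable for accumulators (as s32).
// LLVM struct types are resolved through the type of their first element.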
164 std::optional<mlir::NVVM::MMATypes> 165 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { 166 auto half2Type = 167 LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2); 168 if (operandElType.isF64()) 169 return NVVM::MMATypes::f64; 170 if (operandElType.isF16() || operandElType == half2Type) 171 return NVVM::MMATypes::f16; 172 if (operandElType.isF32() && isAccumulator) 173 return NVVM::MMATypes::f32; 174 if (operandElType.isF32() && !isAccumulator) 175 return NVVM::MMATypes::tf32; 176 if (llvm::isa<IntegerType>(operandElType)) { 177 if (isAccumulator) 178 return NVVM::MMATypes::s32; 179 return std::nullopt; 180 } 181 182 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) { 183 if (structType.getBody().empty()) 184 return std::nullopt; 185 return inferOperandMMAType(structType.getBody()[0], isAccumulator); 186 } 187 188 return std::nullopt; 189 } 190 191 static bool isInt4PtxType(MMATypes type) { 192 return (type == MMATypes::u4 || type == MMATypes::s4); 193 } 194 195 static bool isInt8PtxType(MMATypes type) { 196 return (type == MMATypes::u8 || type == MMATypes::s8); 197 } 198 199 static bool isIntegerPtxType(MMATypes type) { 200 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 || 201 type == MMATypes::s32; 202 } 203 204 MMATypes MmaOp::accumPtxType() { 205 std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType( 206 getODSOperands(2).getTypes().front(), /*isAccum=*/true); 207 assert(val.has_value() && "accumulator PTX type should always be inferrable"); 208 return val.value(); 209 } 210 211 MMATypes MmaOp::resultPtxType() { 212 std::optional<mlir::NVVM::MMATypes> val = 213 inferOperandMMAType(getResult().getType(), /*isAccum=*/true); 214 assert(val.has_value() && "result PTX type should always be inferrable"); 215 return val.value(); 216 } 217 218 void MmaOp::print(OpAsmPrinter &p) { 219 SmallVector<Type, 4> regTypes; 220 struct OperandFragment { 221 StringRef operandName; 222 StringRef ptxTypeAttr; 223 SmallVector<Value, 4> regs; 224 explicit OperandFragment(StringRef name, StringRef ptxTypeName) 225 : operandName(name), ptxTypeAttr(ptxTypeName) {} 226 }; 227 228 std::array<OperandFragment, 3> frags{ 229 OperandFragment("A", getMultiplicandAPtxTypeAttrName()), 230 OperandFragment("B", getMultiplicandBPtxTypeAttrName()), 231 OperandFragment("C", "")}; 232 SmallVector<StringRef, 4> ignoreAttrNames{ 233 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()}; 234 235 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { 236 auto &frag = frags[fragIdx]; 237 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); 238 for (auto operandIdx = varOperandSpec.first; 239 operandIdx < varOperandSpec.first + varOperandSpec.second; 240 operandIdx++) { 241 frag.regs.push_back(this->getOperand(operandIdx)); 242 if (operandIdx == 0) { 243 regTypes.push_back(this->getOperand(operandIdx).getType()); 244 } 245 } 246 std::optional<MMATypes> inferredType = 247 inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2); 248 if (inferredType) 249 ignoreAttrNames.push_back(frag.ptxTypeAttr); 250 } 251 252 auto printMmaOperand = [&](const OperandFragment &frag) -> void { 253 p << " " << frag.operandName; 254 p << "["; 255 p.printOperands(frag.regs); 256 p << "] "; 257 }; 258 259 for (const auto &frag : frags) { 260 printMmaOperand(frag); 261 } 262 263 p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); 264 265 // Print the types of the operands and result. 
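  // For an f16 m16n8k16 MMA, for instance, the trailing signature printed
  // below would look something like:
  //   : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
  //       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>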
266 p << " : " << "("; 267 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(), 268 frags[1].regs[0].getType(), 269 frags[2].regs[0].getType()}, 270 p); 271 p << ")"; 272 p.printArrowTypeList(TypeRange{this->getRes().getType()}); 273 } 274 275 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType, 276 ValueRange operandA, ValueRange operandB, ValueRange operandC, 277 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op, 278 std::optional<MMAIntOverflow> intOverflow, 279 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes, 280 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) { 281 282 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); 283 MLIRContext *ctx = builder.getContext(); 284 result.addAttribute( 285 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); 286 287 result.addOperands(operandA); 288 result.addOperands(operandB); 289 result.addOperands(operandC); 290 291 if (multiplicandPtxTypes) { 292 result.addAttribute("multiplicandAPtxType", 293 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); 294 result.addAttribute("multiplicandBPtxType", 295 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); 296 } else { 297 if (auto res = inferOperandMMAType(operandA[0].getType(), false)) 298 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); 299 if (auto res = inferOperandMMAType(operandB[0].getType(), false)) 300 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); 301 } 302 303 if (multiplicandLayouts) { 304 result.addAttribute("layoutA", 305 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0])); 306 result.addAttribute("layoutB", 307 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1])); 308 } else { 309 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row)); 310 result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col)); 311 } 312 313 if (intOverflow.has_value()) 314 result.addAttribute("intOverflowBehavior", 315 MMAIntOverflowAttr::get(ctx, *intOverflow)); 316 if (b1Op.has_value()) 317 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op)); 318 319 result.addTypes(resultType); 320 result.addAttribute( 321 MmaOp::getOperandSegmentSizeAttr(), 322 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), 323 static_cast<int32_t>(operandB.size()), 324 static_cast<int32_t>(operandC.size())})); 325 } 326 327 // <operation> := 328 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]` 329 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0])) 330 // `->` type($res) 331 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) { 332 struct OperandFragment { 333 std::optional<MMATypes> elemtype; 334 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; 335 SmallVector<Type> regTypes; 336 }; 337 338 Builder &builder = parser.getBuilder(); 339 std::array<OperandFragment, 4> frags; 340 341 NamedAttrList namedAttributes; 342 343 // A helper to parse the operand segments. 344 auto parseMmaOperand = [&](StringRef operandName, 345 OperandFragment &frag) -> LogicalResult { 346 if (parser.parseKeyword(operandName).failed()) 347 return failure(); 348 if (parser 349 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) 350 .failed()) 351 return failure(); 352 return success(); 353 }; 354 355 // Parse the operand segments. 
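  // For example: `A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]`, followed by
  // an optional attribute dictionary and the trailing type signature.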
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype =
        inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
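  // The candidates below are populated for the requested shape and
  // multiplicand type; the actual operand segments and result type are then
  // checked against them.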
448 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; 449 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; 450 AllowedShapes allowedShapes; 451 AllowedTypes expectedA; 452 AllowedTypes expectedB; 453 AllowedTypes expectedC; 454 SmallVector<Type> expectedResult; 455 456 // When M = 16, we just need to calculate the number of 8xk tiles, where 457 // k is a factor that depends on the data type. 458 if (mmaShape[0] == 16) { 459 int64_t kFactor; 460 Type multiplicandFragType; 461 switch (*getMultiplicandAPtxType()) { 462 case MMATypes::tf32: 463 kFactor = 4; 464 multiplicandFragType = i32Ty; 465 expectedResult.push_back(LLVM::LLVMStructType::getLiteral( 466 context, {f32Ty, f32Ty, f32Ty, f32Ty})); 467 break; 468 case MMATypes::bf16: 469 kFactor = 8; 470 multiplicandFragType = i32Ty; 471 expectedResult.push_back(LLVM::LLVMStructType::getLiteral( 472 context, {f32Ty, f32Ty, f32Ty, f32Ty})); 473 break; 474 case MMATypes::f16: 475 kFactor = 8; 476 multiplicandFragType = f16x2Ty; 477 expectedResult.push_back(f16x2x2StructTy); 478 expectedResult.push_back(f32x4StructTy); 479 break; 480 case MMATypes::s4: 481 case MMATypes::u4: 482 kFactor = 32; 483 break; 484 case MMATypes::b1: 485 kFactor = 128; 486 break; 487 case MMATypes::s8: 488 case MMATypes::u8: 489 kFactor = 16; 490 break; 491 default: 492 return emitError("invalid shape or multiplicand type: " + 493 stringifyEnum(getMultiplicandAPtxType().value())); 494 } 495 496 if (isIntegerPtxType(getMultiplicandAPtxType().value())) { 497 expectedResult.push_back(s32x4StructTy); 498 expectedC.emplace_back(4, i32Ty); 499 multiplicandFragType = i32Ty; 500 } else { 501 expectedC.emplace_back(2, f16x2Ty); 502 expectedC.emplace_back(4, f32Ty); 503 } 504 505 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor); 506 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); 507 expectedA.emplace_back(unitA, multiplicandFragType); 508 expectedB.emplace_back(unitB, multiplicandFragType); 509 allowedShapes.push_back({16, 8, kFactor}); 510 allowedShapes.push_back({16, 8, kFactor * 2}); 511 } 512 513 // In the M=8 case, there is only 1 possible case per data type. 
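  // For example, the f64 variant only exists as m8n8k4, with one f64 register
  // per multiplicand fragment and two f64 registers for the accumulator.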
514 if (mmaShape[0] == 8) { 515 if (*getMultiplicandAPtxType() == MMATypes::f16) { 516 expectedA.emplace_back(2, f16x2Ty); 517 expectedB.emplace_back(2, f16x2Ty); 518 expectedResult.push_back(f16x2x4StructTy); 519 expectedResult.push_back(f32x8StructTy); 520 expectedC.emplace_back(4, f16x2Ty); 521 expectedC.emplace_back(8, f32Ty); 522 allowedShapes.push_back({8, 8, 4}); 523 } 524 if (*getMultiplicandAPtxType() == MMATypes::f64) { 525 Type f64Ty = Float64Type::get(context); 526 expectedA.emplace_back(1, f64Ty); 527 expectedB.emplace_back(1, f64Ty); 528 expectedC.emplace_back(2, f64Ty); 529 // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2)); 530 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( 531 context, SmallVector<Type>(2, f64Ty))); 532 allowedShapes.push_back({8, 8, 4}); 533 } 534 if (isIntegerPtxType(getMultiplicandAPtxType().value())) { 535 expectedA.push_back({i32Ty}); 536 expectedB.push_back({i32Ty}); 537 expectedC.push_back({i32Ty, i32Ty}); 538 expectedResult.push_back(s32x2StructTy); 539 if (isInt4PtxType(getMultiplicandAPtxType().value())) 540 allowedShapes.push_back({8, 8, 32}); 541 if (isInt8PtxType(getMultiplicandAPtxType().value())) 542 allowedShapes.push_back({8, 8, 16}); 543 if (getMultiplicandAPtxType().value() == MMATypes::b1) 544 allowedShapes.push_back({8, 8, 128}); 545 } 546 } 547 548 std::string errorMessage; 549 llvm::raw_string_ostream errorStream(errorMessage); 550 551 // Check that we matched an existing shape/dtype combination. 552 if (expectedA.empty() || expectedB.empty() || expectedC.empty() || 553 !llvm::is_contained(allowedShapes, mmaShape)) { 554 errorStream << "unimplemented variant for MMA shape <"; 555 llvm::interleaveComma(mmaShape, errorStream); 556 errorStream << ">"; 557 return emitOpError(errorMessage); 558 } 559 560 // Verify the operand types for segments of A, B, and C operands. 561 std::array<StringRef, 3> operandNames{"A", "B", "C"}; 562 for (const auto &iter : llvm::enumerate( 563 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { 564 auto spec = this->getODSOperandIndexAndLength(iter.index()); 565 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, 566 operand_type_begin() + spec.first + 567 spec.second); 568 bool match = llvm::is_contained(iter.value(), operandTySeg); 569 570 if (!match) { 571 errorStream << "Could not match types for the " 572 << operandNames[iter.index()] 573 << " operands; expected one of "; 574 for (const auto &x : iter.value()) { 575 errorStream << x.size() << "x" << x[0] << " "; 576 } 577 errorStream << "but got "; 578 llvm::interleaveComma(operandTySeg, errorStream); 579 return emitOpError(errorMessage); 580 } 581 } 582 583 // Check the result type 584 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { 585 return expectedResultType == getResult().getType(); 586 })) { 587 errorStream 588 << "Could not match allowed types for the result; expected one of "; 589 llvm::interleaveComma(expectedResult, errorStream); 590 errorStream << " but got " << getResult().getType(); 591 return emitOpError(errorMessage); 592 } 593 594 // Ensure that binary MMA variants have a b1 MMA operation defined. 595 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) { 596 return emitOpError("op requires " + getB1OpAttrName().strref() + 597 " attribute"); 598 } 599 600 // Ensure int4/int8 MMA variants specify the accum overflow behavior 601 // attribute. 
602 if (isInt4PtxType(*getMultiplicandAPtxType()) || 603 isInt8PtxType(*getMultiplicandAPtxType())) { 604 if (!getIntOverflowBehavior()) 605 return emitOpError("op requires " + 606 getIntOverflowBehaviorAttrName().strref() + 607 " attribute"); 608 } 609 610 return success(); 611 } 612 613 LogicalResult ShflOp::verify() { 614 if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) 615 return success(); 616 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); 617 auto elementType = (type && type.getBody().size() == 2) 618 ? llvm::dyn_cast<IntegerType>(type.getBody()[1]) 619 : nullptr; 620 if (!elementType || elementType.getWidth() != 1) 621 return emitError("expected return type to be a two-element struct with " 622 "i1 as the second element"); 623 return success(); 624 } 625 626 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, 627 NVVM::MMAFrag frag, int nRow, 628 int nCol, 629 MLIRContext *context) { 630 unsigned numberElements = 0; 631 Type elementType; 632 OpBuilder builder(context); 633 Type f16x2 = VectorType::get(2, builder.getF16Type()); 634 if (type == NVVM::MMATypes::f16) { 635 elementType = f16x2; 636 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) 637 numberElements = 8; 638 else 639 numberElements = 4; 640 } else if (type == NVVM::MMATypes::f32) { 641 elementType = builder.getF32Type(); 642 numberElements = 8; 643 } else if (type == NVVM::MMATypes::tf32) { 644 elementType = builder.getI32Type(); 645 numberElements = 4; 646 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) { 647 elementType = builder.getI32Type(); 648 int parallelSize = 0; 649 if (frag == NVVM::MMAFrag::a) 650 parallelSize = nRow; 651 if (frag == NVVM::MMAFrag::b) 652 parallelSize = nCol; 653 654 // m == 16 && n == 16 && k == 16 655 if (parallelSize == 16) 656 numberElements = 2; 657 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16 658 else if (parallelSize == 8) 659 numberElements = 1; 660 else if (parallelSize == 32) 661 numberElements = 4; 662 } else if (type == NVVM::MMATypes::s32) { 663 elementType = builder.getI32Type(); 664 numberElements = 8; 665 } 666 assert(numberElements != 0 && elementType != nullptr); 667 return std::make_pair(elementType, numberElements); 668 } 669 670 static std::pair<mlir::Type, unsigned> 671 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, 672 int k, MLIRContext *context) { 673 int nRow, nCol; 674 if (frag == NVVM::MMAFrag::a) { 675 nRow = m; 676 nCol = k; 677 } else if (frag == NVVM::MMAFrag::b) { 678 nRow = k; 679 nCol = n; 680 } else { 681 nRow = m; 682 nCol = n; 683 } 684 assert(nRow && nCol); 685 return inferMMAType(type, frag, nRow, nCol, context); 686 } 687 688 LogicalResult NVVM::WMMALoadOp::verify() { 689 unsigned addressSpace = 690 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 691 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && 692 addressSpace != NVVM::kSharedMemorySpace) 693 return emitOpError("expected source pointer in memory " 694 "space 0, 1, 3"); 695 696 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), 697 getEltype(), getFrag()) == 0) 698 return emitOpError() << "invalid attribute combination"; 699 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( 700 getEltype(), getFrag(), getM(), getN(), getK(), getContext()); 701 Type dstType = LLVM::LLVMStructType::getLiteral( 702 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); 703 if (getType() != dstType) 704 
return emitOpError("expected destination type is a structure of ") 705 << typeInfo.second << " elements of type " << typeInfo.first; 706 return success(); 707 } 708 709 LogicalResult NVVM::WMMAStoreOp::verify() { 710 unsigned addressSpace = 711 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 712 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && 713 addressSpace != NVVM::kSharedMemorySpace) 714 return emitOpError("expected operands to be a source pointer in memory " 715 "space 0, 1, 3"); 716 717 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), 718 getEltype()) == 0) 719 return emitOpError() << "invalid attribute combination"; 720 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( 721 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); 722 if (getArgs().size() != typeInfo.second) 723 return emitOpError() << "expected " << typeInfo.second << " data operands"; 724 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) { 725 return operands.getType() != typeInfo.first; 726 })) 727 return emitOpError() << "expected data operands of type " << typeInfo.first; 728 return success(); 729 } 730 731 LogicalResult NVVM::WMMAMmaOp::verify() { 732 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(), 733 getLayoutB(), getEltypeA(), 734 getEltypeB()) == 0) 735 return emitOpError() << "invalid attribute combination"; 736 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK( 737 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext()); 738 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK( 739 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext()); 740 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK( 741 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); 742 SmallVector<Type, 32> arguments; 743 arguments.append(typeInfoA.second, typeInfoA.first); 744 arguments.append(typeInfoB.second, typeInfoB.first); 745 arguments.append(typeInfoC.second, typeInfoC.first); 746 unsigned numArgs = arguments.size(); 747 if (getArgs().size() != numArgs) 748 return emitOpError() << "expected " << numArgs << " arguments"; 749 for (unsigned i = 0; i < numArgs; i++) { 750 if (getArgs()[i].getType() != arguments[i]) 751 return emitOpError() << "expected argument " << i << " to be of type " 752 << arguments[i]; 753 } 754 Type dstType = LLVM::LLVMStructType::getLiteral( 755 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first)); 756 if (getType() != dstType) 757 return emitOpError("expected destination type is a structure of ") 758 << typeInfoC.second << " elements of type " << typeInfoC.first; 759 return success(); 760 } 761 762 LogicalResult NVVM::LdMatrixOp::verify() { 763 unsigned addressSpace = 764 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 765 if (addressSpace != NVVM::kSharedMemorySpace) 766 return emitOpError("expected source pointer in memory space 3"); 767 768 if (getNum() != 1 && getNum() != 2 && getNum() != 4) 769 return emitOpError("expected num attribute to be 1, 2 or 4"); 770 771 Type i32 = IntegerType::get(getContext(), 32); 772 if (getNum() == 1 && getType() != i32) 773 return emitOpError("expected destination type is i32"); 774 if (getNum() == 2 || getNum() == 4) { 775 Type dstType = LLVM::LLVMStructType::getLiteral( 776 getContext(), SmallVector<Type>(getNum(), i32)); 777 if (getType() != dstType) 778 return emitOpError("expected destination type is a structure of ") 779 << getNum() << 
" elements of type i32"; 780 } 781 return success(); 782 } 783 784 LogicalResult NVVM::StMatrixOp::verify() { 785 unsigned addressSpace = 786 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 787 if (addressSpace != NVVM::kSharedMemorySpace) 788 return emitOpError("expected source pointer in memory space 3"); 789 790 int numMatrix = getSources().size(); 791 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) 792 return emitOpError("expected num attribute to be 1, 2 or 4"); 793 794 return success(); 795 } 796 797 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) { 798 if (typeA == NVVM::WGMMATypes::tf32) 799 return 8; 800 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16) 801 return 16; 802 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8) 803 return 32; 804 if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2) 805 return 32; 806 if (typeA == NVVM::WGMMATypes::b1) 807 return 256; 808 return failure(); 809 } 810 811 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, 812 NVVM::WGMMATypes typeA, 813 NVVM::WGMMATypes typeB) { 814 switch (typeA) { 815 case NVVM::WGMMATypes::f16: 816 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 817 typeB == NVVM::WGMMATypes::f16) 818 return success(); 819 break; 820 case NVVM::WGMMATypes::tf32: 821 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32) 822 return success(); 823 break; 824 case NVVM::WGMMATypes::u8: 825 case NVVM::WGMMATypes::s8: 826 if (typeD == NVVM::WGMMATypes::s32 && 827 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) 828 return success(); 829 break; 830 case NVVM::WGMMATypes::b1: 831 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1) 832 return success(); 833 break; 834 case NVVM::WGMMATypes::bf16: 835 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 836 typeB == NVVM::WGMMATypes::bf16) 837 return success(); 838 break; 839 case NVVM::WGMMATypes::e4m3: 840 case NVVM::WGMMATypes::e5m2: 841 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 842 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3)) 843 return success(); 844 break; 845 case WGMMATypes::f32: 846 case WGMMATypes::s32: 847 llvm_unreachable("unsupported input types"); 848 break; 849 } 850 return failure(); 851 } 852 853 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { 854 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, 855 72, 80, 88, 96, 104, 112, 120, 128, 856 136, 144, 152, 160, 168, 176, 184, 192, 857 200, 208, 216, 224, 232, 240, 248, 256}; 858 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, 859 80, 96, 112, 128, 144, 160, 860 176, 192, 208, 224, 240, 256}; 861 switch (typeA) { 862 case WGMMATypes::f16: 863 case WGMMATypes::tf32: 864 case WGMMATypes::bf16: 865 case WGMMATypes::e4m3: 866 case WGMMATypes::e5m2: 867 if (llvm::is_contained(allowedN, sizeN)) 868 return success(); 869 break; 870 case WGMMATypes::u8: 871 case WGMMATypes::s8: 872 case WGMMATypes::b1: 873 if (llvm::is_contained(allowedNshort, sizeN)) 874 return success(); 875 break; 876 case WGMMATypes::f32: 877 case WGMMATypes::s32: 878 llvm_unreachable("unsupported input types"); 879 break; 880 } 881 return failure(); 882 } 883 884 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { 885 Value outValue = getResults(); 886 auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType()); 887 if (!stype) 888 return emitOpError() << "expected results to 
be struct"; 889 int outputSize = stype.getBody().size(); 890 WGMMATypes typeD = getTypeD(); 891 WGMMATypes typeA = getTypeA(); 892 WGMMATypes typeB = getTypeB(); 893 894 for (Type t : stype.getBody()) { 895 if (t != stype.getBody().front()) 896 return emitOpError() 897 << "all elements in struct must be same type but there is " << t; 898 } 899 900 if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 && 901 typeD != WGMMATypes::s32) { 902 return emitOpError() << "does not support the given output type " 903 << NVVM::stringifyWGMMATypes(typeD); 904 } 905 if (typeD == WGMMATypes::s32 && 906 (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) { 907 return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg"; 908 } 909 910 if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) { 911 return emitOpError() << NVVM::stringifyWGMMATypes(typeD) 912 << " += " << NVVM::stringifyWGMMATypes(typeA) << " * " 913 << NVVM::stringifyWGMMATypes(typeB) 914 << ", it is not supported."; 915 } 916 917 // Check M 918 if (getShape().getM() != 64) 919 return emitOpError() << "shape 'm' must be 64"; 920 921 // Check K 922 FailureOr<int> allowedK = getAllowedSizeK(typeA); 923 if (failed(allowedK) || allowedK.value() != getShape().getK()) 924 return emitOpError() << "shape 'k' must be " << allowedK.value() 925 << " for input type " 926 << NVVM::stringifyWGMMATypes(typeA); 927 928 // Check N 929 if (failed(isAllowedSizeN(getShape().getN(), typeA))) { 930 return emitOpError() << "has input type " 931 << NVVM::stringifyWGMMATypes(typeA) << " n is set to " 932 << getShape().getN() << ", it is not supported."; 933 } 934 935 // Check transpose (only available for f16/bf16) 936 // Matrices A should be stored in row-major and B in column-major. 937 // Only f16/bf16 matrices can be stored in either column-major or row-major 938 // by setting the transpose value(imm-trans-a,imm-trans-b) in PTX code. 939 if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) && 940 (getLayoutA() == mlir::NVVM::MMALayout::col || 941 getLayoutB() == mlir::NVVM::MMALayout::row)) { 942 return emitOpError() 943 << "given layouts layout_a = " << stringifyMMALayout(getLayoutA()) 944 << " and layout_b = " << stringifyMMALayout(getLayoutB()) 945 << " for input types " << stringifyWGMMATypes(typeA) << " and " 946 << stringifyWGMMATypes(typeB) 947 << " requires transpose. 
However, this is only supported for: " 948 << stringifyMMATypes(MMATypes::f16) << " and " 949 << stringifyMMATypes(MMATypes::bf16); 950 } 951 952 // Check result registers 953 int expectedOutput = 0; 954 if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32) 955 expectedOutput = getShape().getN() / 2; 956 if (typeD == WGMMATypes::f16) 957 expectedOutput = getShape().getN() / 4; 958 if (outputSize != expectedOutput) { 959 return emitOpError() << "results " << expectedOutput 960 << ", however output struct has " << outputSize 961 << " elements"; 962 } 963 // Check satfinite (only available for s32 accumulator) 964 if (typeD != WGMMATypes::s32 && 965 getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == 966 NVVM::MMAIntOverflow::satfinite) { 967 return emitOpError() 968 << " `satfinite` can be only used with s32 accumulator, however " 969 "the current accumulator is " 970 << NVVM::stringifyWGMMATypes(typeD); 971 } 972 973 return success(); 974 } 975 976 std::string NVVM::WgmmaMmaAsyncOp::getPtx() { 977 978 int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); 979 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; 980 981 StringRef outputTypeName = stringifyWGMMATypes(getTypeD()); 982 983 int expectedOutputRegisters = 0; 984 if (getTypeD() == WGMMATypes::f16) 985 expectedOutputRegisters = getShape().getN() / 4; 986 else 987 expectedOutputRegisters = getShape().getN() / 2; 988 989 std::string ptx; 990 llvm::raw_string_ostream ss(ptx); 991 992 ss << "{\n" 993 ".reg .pred p;\n" 994 "setp.ne.b32 p, $" 995 << ((expectedOutputRegisters * 2) + 2) 996 << ", 0;\n" 997 "wgmma.mma_async.sync.aligned.m" 998 << m << "n" << n << "k" << k << "." << outputTypeName << "." 999 << stringifyWGMMATypes(getTypeA()) << "." 1000 << stringifyWGMMATypes(getTypeB()); 1001 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == 1002 NVVM::MMAIntOverflow::satfinite) 1003 ss << ".satfinite"; 1004 ss << " {"; 1005 int regCnt = 0; 1006 for (; regCnt < expectedOutputRegisters; ++regCnt) { 1007 ss << "$" << regCnt; 1008 if (regCnt != expectedOutputRegisters - 1) 1009 ss << ", "; 1010 } 1011 1012 ss << "},"; 1013 // Need to map read/write registers correctly. 1014 regCnt = (regCnt * 2); 1015 ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; 1016 if (getTypeD() != WGMMATypes::s32) { 1017 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); 1018 } 1019 // Don't add transpose parameters unless needed. 1020 if (isF16) { 1021 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6); 1022 } 1023 ss << ";\n" 1024 << "}\n"; 1025 return ptx; 1026 } 1027 1028 void NVVM::WgmmaMmaAsyncOp::getAsmValues( 1029 RewriterBase &rewriter, 1030 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> 1031 &asmValues) { 1032 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; 1033 if (getResults()) 1034 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); 1035 if (getInouts()) 1036 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite}); 1037 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read}); 1038 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); 1039 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())), 1040 mlir::NVVM::PTXRegisterMod::Read}); 1041 if (getTypeD() != WGMMATypes::s32) { 1042 asmValues.push_back( 1043 {makeConstantI32(rewriter, 1044 getScaleA() == NVVM::WGMMAScaleIn::neg ? 
-1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = cpAsyncOp.getCpSize() ? true : false;
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ?
GET_CP_ASYNC_ID(cg, 16, hasCpSize) 1136 : GET_CP_ASYNC_ID(ca, 16, hasCpSize); 1137 break; 1138 default: 1139 llvm_unreachable("Invalid copy size in CpAsyncOp."); 1140 } 1141 1142 // Fill the Intrinsic Args 1143 args.push_back(mt.lookupValue(cpAsyncOp.getDst())); 1144 args.push_back(mt.lookupValue(cpAsyncOp.getSrc())); 1145 if (hasCpSize) 1146 args.push_back(mt.lookupValue(cpAsyncOp.getCpSize())); 1147 1148 return id; 1149 } 1150 1151 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, 1152 bool isIm2Col) { 1153 switch (tensorDims) { 1154 case 1: 1155 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d; 1156 case 2: 1157 return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d; 1158 case 3: 1159 return isIm2Col 1160 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d 1161 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d; 1162 case 4: 1163 return isIm2Col 1164 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d 1165 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d; 1166 case 5: 1167 return isIm2Col 1168 ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d 1169 : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d; 1170 default: 1171 llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp."); 1172 } 1173 } 1174 1175 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \ 1176 llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d 1177 1178 #define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col) \ 1179 is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \ 1180 : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile) 1181 1182 #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \ 1183 [&]() -> auto { \ 1184 switch (dims) { \ 1185 case 1: \ 1186 return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \ 1187 case 2: \ 1188 return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \ 1189 case 3: \ 1190 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \ 1191 case 4: \ 1192 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \ 1193 case 5: \ 1194 return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \ 1195 default: \ 1196 llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \ 1197 } \ 1198 }() 1199 1200 llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID( 1201 int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) { 1202 using RedTy = NVVM::TMAReduxKind; 1203 switch (kind) { 1204 case RedTy::ADD: 1205 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col); 1206 case RedTy::MIN: 1207 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col); 1208 case RedTy::MAX: 1209 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col); 1210 case RedTy::INC: 1211 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col); 1212 case RedTy::DEC: 1213 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col); 1214 case RedTy::AND: 1215 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col); 1216 case RedTy::OR: 1217 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col); 1218 case RedTy::XOR: 1219 return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col); 1220 } 1221 llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp"); 1222 } 1223 1224 llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, 1225 NVVM::SaturationMode sat, 1226 bool hasRelu) { 1227 using RndMode = NVVM::FPRoundingMode; 1228 switch 
(rnd) { 1229 case RndMode::RN: 1230 return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu 1231 : llvm::Intrinsic::nvvm_f2tf32_rn; 1232 case RndMode::RZ: 1233 return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu 1234 : llvm::Intrinsic::nvvm_f2tf32_rz; 1235 case RndMode::RNA: 1236 return (sat == NVVM::SaturationMode::SATFINITE) 1237 ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite 1238 : llvm::Intrinsic::nvvm_f2tf32_rna; 1239 default: 1240 llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op"); 1241 } 1242 } 1243 1244 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might 1245 /// have ConstantRangeAttr. 1246 static void nvvmInferResultRanges(Operation *op, Value result, 1247 ArrayRef<::mlir::ConstantIntRanges> argRanges, 1248 SetIntRangeFn setResultRanges) { 1249 if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) { 1250 setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(), 1251 rangeAttr.getLower(), rangeAttr.getUpper()}); 1252 } 1253 } 1254 1255 //===----------------------------------------------------------------------===// 1256 // NVVMDialect initialization, type parsing, and registration. 1257 //===----------------------------------------------------------------------===// 1258 1259 // TODO: This should be the llvm.nvvm dialect once this is supported. 1260 void NVVMDialect::initialize() { 1261 addOperations< 1262 #define GET_OP_LIST 1263 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" 1264 >(); 1265 addAttributes< 1266 #define GET_ATTRDEF_LIST 1267 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" 1268 >(); 1269 1270 // Support unknown operations because not all NVVM operations are 1271 // registered. 1272 allowUnknownOperations(); 1273 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>(); 1274 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>(); 1275 } 1276 1277 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, 1278 NamedAttribute attr) { 1279 StringAttr attrName = attr.getName(); 1280 // Kernel function attribute should be attached to functions. 
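  // For example, the kernel attribute is accepted on an `llvm.func`, while
  // attaching it to any other operation is diagnosed below.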
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most 3 elements.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"