1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the NVVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 // The NVVM dialect only contains GPU specific additions on top of the general 13 // LLVM dialect. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 18 19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 20 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/BuiltinAttributes.h" 24 #include "mlir/IR/BuiltinTypes.h" 25 #include "mlir/IR/Diagnostics.h" 26 #include "mlir/IR/DialectImplementation.h" 27 #include "mlir/IR/MLIRContext.h" 28 #include "mlir/IR/Operation.h" 29 #include "mlir/IR/OperationSupport.h" 30 #include "mlir/IR/Types.h" 31 #include "llvm/ADT/STLExtras.h" 32 #include "llvm/ADT/TypeSwitch.h" 33 #include "llvm/AsmParser/Parser.h" 34 #include "llvm/IR/Attributes.h" 35 #include "llvm/IR/Function.h" 36 #include "llvm/IR/Type.h" 37 #include "llvm/Support/Casting.h" 38 #include "llvm/Support/SourceMgr.h" 39 #include "llvm/Support/raw_ostream.h" 40 #include <cassert> 41 #include <optional> 42 #include <string> 43 44 using namespace mlir; 45 using namespace NVVM; 46 47 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" 48 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" 49 50 //===----------------------------------------------------------------------===// 51 // Printing/parsing for NVVM ops 52 
//===----------------------------------------------------------------------===//

// Generic printer for NVVM intrinsic-style ops: prints the operand list and,
// when the op produces results, " : " followed by the result types.
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  // Operand types are fixed by the intrinsic: an i32 mask followed by an i1
  // predicate; only the result type is spelled out in the assembly.
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// Verifies the coordinate count (1-5) and, when im2col offsets are present,
// that the tensor is at least 3-D and the offset count is exactly two less
// than the coordinate count.
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  if (getCoordinates().empty() || getCoordinates().size() > 5)
    return emitError("expects coordinates between 1 to 5 dimension");

  // Check for im2col mode
  if (!getIm2colOffsets().empty()) {
    if (getCoordinates().size() < 3)
      return emitError(
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
      return emitError(
          "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("Maximum 5 coordinates and dimension is supported.");
  return success();
}

// cp.async accepts only the CA and CG cache modifiers and byte sizes
// 4/8/16; CG additionally requires the 16-byte size.
LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only support for 16 bytes copy.");
  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  // f32 means "f32 accumulator" but "tf32 multiplicand".
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    // Integer accumulators are always s32; a multiplicand's signedness cannot
    // be recovered from the IR type alone, so it stays uninferred.
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    // Infer from the first struct element (fragments are homogeneous).
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

// True for the 4-bit integer PTX types.
static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

// True for the 8-bit integer PTX types.
static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

// True for any integer PTX type, including b1 and the s32 accumulator type.
static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

// Returns the PTX type of the accumulator operand (operand segment 2).
MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccum=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

// Returns the PTX type corresponding to the op's result struct.
MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

// Prints the op in the form:
//   A[regs] B[regs] C[regs] attr-dict : (typeA, typeB, typeC) -> resType
// PTX-type attributes that can be re-inferred from the operand types are
// elided from the attribute dictionary.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Groups the assembly name, PTX-type attribute name, and registers of one
  // operand segment.
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): only overall operand 0 ever appends to regTypes, so
      // `regTypes.back()` below always sees the first A register's type for
      // every fragment — confirm this is the intended behavior for B and C.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    // If the PTX type is inferrable there is no need to print its attribute.
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

// Builds an MmaOp, populating shape, layout and PTX-type attributes. When
// `multiplicandPtxTypes` is absent the PTX types are inferred from the first
// A/B operand's element type; when `multiplicandLayouts` is absent layouts
// default to row-major A and column-major B.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    // Best-effort inference; the attribute is simply omitted when the element
    // type is ambiguous (e.g. integer multiplicands).
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//   `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  // Per-segment parse state; index 3 holds the result's inferred element type.
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
317 SmallVector<Type, 3> operandTypes; 318 if (failed(parser.parseColon())) 319 return failure(); 320 if (failed(parser.parseLParen())) 321 return failure(); 322 if (failed(parser.parseTypeList(operandTypes))) 323 return failure(); 324 if (failed(parser.parseRParen())) 325 if (operandTypes.size() != 3) 326 return parser.emitError( 327 parser.getNameLoc(), 328 "expected one type for each operand segment but got " + 329 Twine(operandTypes.size()) + " types"); 330 for (const auto &iter : llvm::enumerate(operandTypes)) { 331 auto &frag = frags[iter.index()]; 332 frag.regTypes.resize(frag.regs.size(), iter.value()); 333 if (failed(parser.resolveOperands(frag.regs, frag.regTypes, 334 parser.getNameLoc(), result.operands))) 335 return failure(); 336 frag.elemtype = 337 inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2); 338 } 339 340 Type resultType; 341 if (parser.parseArrow() || parser.parseType(resultType)) 342 return failure(); 343 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true); 344 345 std::array<StringRef, 2> names{"multiplicandAPtxType", 346 "multiplicandBPtxType"}; 347 for (unsigned idx = 0; idx < names.size(); idx++) { 348 const auto &frag = frags[idx]; 349 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); 350 if (!frag.elemtype.has_value() && !attr.has_value()) { 351 return parser.emitError( 352 parser.getNameLoc(), 353 "attribute " + names[idx] + 354 " is not provided explicitly and cannot be inferred"); 355 } 356 if (!attr.has_value()) 357 result.addAttribute( 358 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); 359 } 360 361 result.addTypes(resultType); 362 if (!namedAttributes.empty()) 363 result.addAttributes(namedAttributes); 364 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(), 365 builder.getDenseI32ArrayAttr({ 366 static_cast<int32_t>(frags[0].regs.size()), 367 static_cast<int32_t>(frags[1].regs.size()), 368 static_cast<int32_t>(frags[2].regs.size()), 369 })); 370 
  return success();
}

// Verifies that the op's shape, PTX types, operand counts/types and result
// type together form one of the mma.sync variants the lowering supports.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants accumulate into s32 and pack multiplicands into i32
    // registers (the fragment type set in the switch is overridden here).
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Each multiplicand fragment covers an 8 x kFactor tile.
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

// With the `return_value_and_is_valid` unit attribute, the result must be a
// two-element struct whose second element is i1; otherwise nothing to check.
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

// Computes the per-thread register element type and element count of one WMMA
// fragment, given the fragment's PTX type, role (a/b/c) and tile dimensions.
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    // tf32 operands are carried in i32 registers.
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if
        (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

// Maps (m, n, k) and the fragment role to that fragment's (rows, cols) tile
// and delegates to inferMMAType.
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

// Checks the source address space, the attribute combination against the
// available intrinsics, and the result struct shape.
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  // Intrinsic ID 0 means no intrinsic exists for this attribute combination.
  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

// Checks the destination address space, the attribute combination against the
// available intrinsics, and the count/type of the data operands.
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

// Checks that the flattened (A, B, C) argument list and the result struct
// match the fragment shapes implied by the attributes. Note eltypeA covers
// both multiplicand fragments; eltypeB is the accumulator/result type.
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

// ldmatrix loads from shared memory; num == 1 yields a single i32 result,
// num == 2/4 a struct of that many i32s.
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

// stmatrix stores to shared memory; 1, 2 or 4 source matrices are allowed.
LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

// Returns the only K dimension WGMMA supports for the given input type, or
// failure when the type cannot be a WGMMA input.
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA ==
          NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

// Returns success iff the (D, A, B) type triple is a supported WGMMA
// combination, keyed on the A-input type.
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    // f32/s32 are accumulator-only types and can never be inputs.
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

// Checks the N dimension against the per-type allow-list; the integer and
// binary input types support a sparser set of N values.
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

// Verifies a wgmma.mma_async op: result struct homogeneity, supported
// (D, A, B) type triple, M/N/K constraints, transpose restrictions, output
// register count, and satfinite applicability.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  // The output struct must be homogeneous.
  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  // Input negation is only meaningful for floating-point accumulation.
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16)
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::col)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator)
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() ==
WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; 922 923 StringRef outputTypeName = stringifyWGMMATypes(getTypeD()); 924 925 int expectedOutputRegisters = 0; 926 if (getTypeD() == WGMMATypes::f16) 927 expectedOutputRegisters = getShape().getN() / 4; 928 else 929 expectedOutputRegisters = getShape().getN() / 2; 930 931 std::string ptx; 932 llvm::raw_string_ostream ss(ptx); 933 934 ss << "{\n" 935 ".reg .pred p;\n" 936 "setp.ne.b32 p, $" 937 << ((expectedOutputRegisters * 2) + 2) 938 << ", 0;\n" 939 "wgmma.mma_async.sync.aligned.m" 940 << m << "n" << n << "k" << k << "." << outputTypeName << "." 941 << stringifyWGMMATypes(getTypeA()) << "." 942 << stringifyWGMMATypes(getTypeB()); 943 if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) == 944 NVVM::MMAIntOverflow::satfinite) 945 ss << ".satfinite"; 946 ss << " {"; 947 int regCnt = 0; 948 for (; regCnt < expectedOutputRegisters; ++regCnt) { 949 ss << "$" << regCnt; 950 if (regCnt != expectedOutputRegisters - 1) 951 ss << ", "; 952 } 953 954 ss << "},"; 955 // Need to map read/write registers correctly. 956 regCnt = (regCnt * 2); 957 ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p"; 958 if (getTypeD() != WGMMATypes::s32) { 959 ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4); 960 } 961 // Don't add transpose parameters unless needed. 
962 if (isF16) { 963 ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6); 964 } 965 ss << ";\n" 966 << "}\n"; 967 ss.flush(); 968 return ptx; 969 } 970 971 void NVVM::WgmmaMmaAsyncOp::getAsmValues( 972 RewriterBase &rewriter, 973 llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> 974 &asmValues) { 975 bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16; 976 if (getResults()) 977 asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write}); 978 if (getInouts()) 979 asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite}); 980 asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read}); 981 asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read}); 982 asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())), 983 mlir::NVVM::PTXRegisterMod::Read}); 984 if (getTypeD() != WGMMATypes::s32) { 985 asmValues.push_back( 986 {makeConstantI32(rewriter, 987 getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1), 988 mlir::NVVM::PTXRegisterMod::Read}); 989 asmValues.push_back( 990 {makeConstantI32(rewriter, 991 getScaleB() == NVVM::WGMMAScaleIn::neg ? 
-1 : 1), 992 mlir::NVVM::PTXRegisterMod::Read}); 993 } 994 if (isF16) { 995 asmValues.push_back( 996 {makeConstantI32(rewriter, static_cast<int>(getLayoutA())), 997 mlir::NVVM::PTXRegisterMod::Read}); 998 asmValues.push_back( 999 {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())), 1000 mlir::NVVM::PTXRegisterMod::Read}); 1001 } 1002 } 1003 LogicalResult NVVM::FenceProxyOp::verify() { 1004 if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) { 1005 return emitOpError() << "async_shared fence requires space attribute"; 1006 } 1007 if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) { 1008 return emitOpError() << "only async_shared fence can have space attribute"; 1009 } 1010 return success(); 1011 } 1012 1013 LogicalResult NVVM::SetMaxRegisterOp::verify() { 1014 if (getRegCount() % 8) 1015 return emitOpError("new register size must be multiple of 8"); 1016 if (getRegCount() < 24 || getRegCount() > 256) 1017 return emitOpError("new register size must be in between 24 to 256"); 1018 return success(); 1019 } 1020 1021 LogicalResult NVVM::BarrierOp::verify() { 1022 if (getNumberOfThreads() && !getBarrierId()) 1023 return emitOpError( 1024 "barrier id is missing, it should be set between 0 to 15"); 1025 return success(); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // NVVMDialect initialization, type parsing, and registration. 1030 //===----------------------------------------------------------------------===// 1031 1032 // TODO: This should be the llvm.nvvm dialect once this is supported. 1033 void NVVMDialect::initialize() { 1034 addOperations< 1035 #define GET_OP_LIST 1036 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" 1037 >(); 1038 addAttributes< 1039 #define GET_ATTRDEF_LIST 1040 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" 1041 >(); 1042 1043 // Support unknown operations because not all NVVM operations are 1044 // registered. 
1045 allowUnknownOperations(); 1046 declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>(); 1047 declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>(); 1048 } 1049 1050 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, 1051 NamedAttribute attr) { 1052 StringAttr attrName = attr.getName(); 1053 // Kernel function attribute should be attached to functions. 1054 if (attrName == NVVMDialect::getKernelFuncAttrName()) { 1055 if (!isa<LLVM::LLVMFuncOp>(op)) { 1056 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() 1057 << "' attribute attached to unexpected op"; 1058 } 1059 } 1060 // If maxntid and reqntid exist, it must be an array with max 3 dim 1061 if (attrName == NVVMDialect::getMaxntidAttrName() || 1062 attrName == NVVMDialect::getReqntidAttrName()) { 1063 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); 1064 if (!values || values.empty() || values.size() > 3) 1065 return op->emitError() 1066 << "'" << attrName 1067 << "' attribute must be integer array with maximum 3 index"; 1068 } 1069 // If minctasm and maxnreg exist, it must be an integer attribute 1070 if (attrName == NVVMDialect::getMinctasmAttrName() || 1071 attrName == NVVMDialect::getMaxnregAttrName()) { 1072 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) 1073 return op->emitError() 1074 << "'" << attrName << "' attribute must be integer constant"; 1075 } 1076 1077 return success(); 1078 } 1079 1080 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op, 1081 unsigned regionIndex, 1082 unsigned argIndex, 1083 NamedAttribute argAttr) { 1084 auto funcOp = dyn_cast<FunctionOpInterface>(op); 1085 if (!funcOp) 1086 return success(); 1087 1088 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName()); 1089 StringAttr attrName = argAttr.getName(); 1090 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) { 1091 if (!isKernel) { 1092 return op->emitError() 1093 << "'" << attrName 1094 << "' attribute 
must be present only on kernel arguments"; 1095 } 1096 if (!isa<UnitAttr>(argAttr.getValue())) 1097 return op->emitError() << "'" << attrName << "' must be a unit attribute"; 1098 if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) { 1099 return op->emitError() 1100 << "'" << attrName 1101 << "' attribute requires the argument to also have attribute '" 1102 << LLVM::LLVMDialect::getByValAttrName() << "'"; 1103 } 1104 } 1105 1106 return success(); 1107 } 1108 1109 //===----------------------------------------------------------------------===// 1110 // NVVM target attribute. 1111 //===----------------------------------------------------------------------===// 1112 LogicalResult 1113 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, 1114 int optLevel, StringRef triple, StringRef chip, 1115 StringRef features, DictionaryAttr flags, 1116 ArrayAttr files) { 1117 if (optLevel < 0 || optLevel > 3) { 1118 emitError() << "The optimization level must be a number between 0 and 3."; 1119 return failure(); 1120 } 1121 if (triple.empty()) { 1122 emitError() << "The target triple cannot be empty."; 1123 return failure(); 1124 } 1125 if (chip.empty()) { 1126 emitError() << "The target chip cannot be empty."; 1127 return failure(); 1128 } 1129 if (files && !llvm::all_of(files, [](::mlir::Attribute attr) { 1130 return attr && mlir::isa<StringAttr>(attr); 1131 })) { 1132 emitError() << "All the elements in the `link` array must be strings."; 1133 return failure(); 1134 } 1135 return success(); 1136 } 1137 1138 #define GET_OP_CLASSES 1139 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" 1140 1141 #define GET_ATTRDEF_CLASSES 1142 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" 1143