1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the NVVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 // The NVVM dialect only contains GPU specific additions on top of the general 13 // LLVM dialect. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 18 19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 20 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" 21 #include "mlir/Dialect/Utils/StaticValueUtils.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/BuiltinAttributes.h" 24 #include "mlir/IR/BuiltinTypes.h" 25 #include "mlir/IR/Diagnostics.h" 26 #include "mlir/IR/DialectImplementation.h" 27 #include "mlir/IR/MLIRContext.h" 28 #include "mlir/IR/Operation.h" 29 #include "mlir/IR/OperationSupport.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Support/LogicalResult.h" 32 #include "llvm/ADT/STLExtras.h" 33 #include "llvm/ADT/TypeSwitch.h" 34 #include "llvm/AsmParser/Parser.h" 35 #include "llvm/IR/Attributes.h" 36 #include "llvm/IR/Function.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/Support/Casting.h" 39 #include "llvm/Support/SourceMgr.h" 40 #include "llvm/Support/raw_ostream.h" 41 #include <cassert> 42 #include <optional> 43 #include <string> 44 45 using namespace mlir; 46 using namespace NVVM; 47 48 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" 49 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" 50 51 //===----------------------------------------------------------------------===// 52 // 
Printing/parsing for NVVM ops 53 //===----------------------------------------------------------------------===// 54 55 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { 56 p << " " << op->getOperands(); 57 if (op->getNumResults() > 0) 58 p << " : " << op->getResultTypes(); 59 } 60 61 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type 62 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) { 63 MLIRContext *context = parser.getContext(); 64 auto int32Ty = IntegerType::get(context, 32); 65 auto int1Ty = IntegerType::get(context, 1); 66 67 SmallVector<OpAsmParser::UnresolvedOperand, 8> ops; 68 Type type; 69 return failure(parser.parseOperandList(ops) || 70 parser.parseOptionalAttrDict(result.attributes) || 71 parser.parseColonType(type) || 72 parser.addTypeToList(type, result.types) || 73 parser.resolveOperands(ops, {int32Ty, int1Ty}, 74 parser.getNameLoc(), result.operands)); 75 } 76 77 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); } 78 79 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() { 80 if (getCoordinates().empty() || getCoordinates().size() > 5) 81 return emitError("expects coordinates between 1 to 5 dimension"); 82 83 // Check for im2col mode 84 if (!getIm2colOffsets().empty()) { 85 if (getCoordinates().size() < 3) 86 return emitError( 87 "to use im2col mode, the tensor has to be at least 3-dimensional"); 88 if (getCoordinates().size() != (getIm2colOffsets().size() + 2)) 89 return emitError( 90 "im2col offsets must be 2 less than number of coordinates"); 91 } 92 return success(); 93 } 94 95 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() { 96 if (getCoordinates().size() > 5) 97 return emitError("Maximum 5 coordinates and dimension is supported."); 98 return success(); 99 } 100 101 LogicalResult CpAsyncOp::verify() { 102 if (getModifier() != LoadCacheModifierKind::CG && 103 getModifier() != LoadCacheModifierKind::CA) 104 return 
emitError("Only CG and CA cache modifiers are supported."); 105 if (getSize() != 4 && getSize() != 8 && getSize() != 16) 106 return emitError("expected byte size to be either 4, 8 or 16."); 107 if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16) 108 return emitError("CG cache modifier is only support for 16 bytes copy."); 109 return success(); 110 } 111 112 // Given the element type of an operand and whether or not it is an accumulator, 113 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the 114 // operand's element type. 115 std::optional<mlir::NVVM::MMATypes> 116 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) { 117 auto half2Type = 118 LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2); 119 if (operandElType.isF64()) 120 return NVVM::MMATypes::f64; 121 if (operandElType.isF16() || operandElType == half2Type) 122 return NVVM::MMATypes::f16; 123 if (operandElType.isF32() && isAccumulator) 124 return NVVM::MMATypes::f32; 125 if (operandElType.isF32() && !isAccumulator) 126 return NVVM::MMATypes::tf32; 127 if (llvm::isa<IntegerType>(operandElType)) { 128 if (isAccumulator) 129 return NVVM::MMATypes::s32; 130 return std::nullopt; 131 } 132 133 if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) { 134 if (structType.getBody().empty()) 135 return std::nullopt; 136 return inferOperandMMAType(structType.getBody()[0], isAccumulator); 137 } 138 139 return std::nullopt; 140 } 141 142 static bool isInt4PtxType(MMATypes type) { 143 return (type == MMATypes::u4 || type == MMATypes::s4); 144 } 145 146 static bool isInt8PtxType(MMATypes type) { 147 return (type == MMATypes::u8 || type == MMATypes::s8); 148 } 149 150 static bool isIntegerPtxType(MMATypes type) { 151 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 || 152 type == MMATypes::s32; 153 } 154 155 MMATypes MmaOp::accumPtxType() { 156 std::optional<mlir::NVVM::MMATypes> val = 
inferOperandMMAType( 157 getODSOperands(2).getTypes().front(), /*isAccum=*/true); 158 assert(val.has_value() && "accumulator PTX type should always be inferrable"); 159 return val.value(); 160 } 161 162 MMATypes MmaOp::resultPtxType() { 163 std::optional<mlir::NVVM::MMATypes> val = 164 inferOperandMMAType(getResult().getType(), /*isAccum=*/true); 165 assert(val.has_value() && "result PTX type should always be inferrable"); 166 return val.value(); 167 } 168 169 void MmaOp::print(OpAsmPrinter &p) { 170 SmallVector<Type, 4> regTypes; 171 struct OperandFragment { 172 StringRef operandName; 173 StringRef ptxTypeAttr; 174 SmallVector<Value, 4> regs; 175 explicit OperandFragment(StringRef name, StringRef ptxTypeName) 176 : operandName(name), ptxTypeAttr(ptxTypeName) {} 177 }; 178 179 std::array<OperandFragment, 3> frags{ 180 OperandFragment("A", getMultiplicandAPtxTypeAttrName()), 181 OperandFragment("B", getMultiplicandBPtxTypeAttrName()), 182 OperandFragment("C", "")}; 183 SmallVector<StringRef, 4> ignoreAttrNames{ 184 mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()}; 185 186 for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { 187 auto &frag = frags[fragIdx]; 188 auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); 189 for (auto operandIdx = varOperandSpec.first; 190 operandIdx < varOperandSpec.first + varOperandSpec.second; 191 operandIdx++) { 192 frag.regs.push_back(this->getOperand(operandIdx)); 193 if (operandIdx == 0) { 194 regTypes.push_back(this->getOperand(operandIdx).getType()); 195 } 196 } 197 std::optional<MMATypes> inferredType = 198 inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2); 199 if (inferredType) 200 ignoreAttrNames.push_back(frag.ptxTypeAttr); 201 } 202 203 auto printMmaOperand = [&](const OperandFragment &frag) -> void { 204 p << " " << frag.operandName; 205 p << "["; 206 p.printOperands(frag.regs); 207 p << "] "; 208 }; 209 210 for (const auto &frag : frags) { 211 printMmaOperand(frag); 212 } 213 214 
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); 215 216 // Print the types of the operands and result. 217 p << " : " 218 << "("; 219 llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(), 220 frags[1].regs[0].getType(), 221 frags[2].regs[0].getType()}, 222 p); 223 p << ")"; 224 p.printArrowTypeList(TypeRange{this->getRes().getType()}); 225 } 226 227 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType, 228 ValueRange operandA, ValueRange operandB, ValueRange operandC, 229 ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op, 230 std::optional<MMAIntOverflow> intOverflow, 231 std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes, 232 std::optional<std::array<MMALayout, 2>> multiplicandLayouts) { 233 234 assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); 235 MLIRContext *ctx = builder.getContext(); 236 result.addAttribute( 237 "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2])); 238 239 result.addOperands(operandA); 240 result.addOperands(operandB); 241 result.addOperands(operandC); 242 243 if (multiplicandPtxTypes) { 244 result.addAttribute("multiplicandAPtxType", 245 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); 246 result.addAttribute("multiplicandBPtxType", 247 MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); 248 } else { 249 if (auto res = inferOperandMMAType(operandA[0].getType(), false)) 250 result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); 251 if (auto res = inferOperandMMAType(operandB[0].getType(), false)) 252 result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); 253 } 254 255 if (multiplicandLayouts) { 256 result.addAttribute("layoutA", 257 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0])); 258 result.addAttribute("layoutB", 259 MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1])); 260 } else { 261 result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row)); 262 
result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col)); 263 } 264 265 if (intOverflow.has_value()) 266 result.addAttribute("intOverflowBehavior", 267 MMAIntOverflowAttr::get(ctx, *intOverflow)); 268 if (b1Op.has_value()) 269 result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op)); 270 271 result.addTypes(resultType); 272 result.addAttribute( 273 MmaOp::getOperandSegmentSizeAttr(), 274 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()), 275 static_cast<int32_t>(operandB.size()), 276 static_cast<int32_t>(operandC.size())})); 277 } 278 279 // <operation> := 280 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]` 281 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0])) 282 // `->` type($res) 283 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) { 284 struct OperandFragment { 285 std::optional<MMATypes> elemtype; 286 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs; 287 SmallVector<Type> regTypes; 288 }; 289 290 Builder &builder = parser.getBuilder(); 291 std::array<OperandFragment, 4> frags; 292 293 NamedAttrList namedAttributes; 294 295 // A helper to parse the operand segments. 296 auto parseMmaOperand = [&](StringRef operandName, 297 OperandFragment &frag) -> LogicalResult { 298 if (parser.parseKeyword(operandName).failed()) 299 return failure(); 300 if (parser 301 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) 302 .failed()) 303 return failure(); 304 return success(); 305 }; 306 307 // Parse the operand segments. 308 if (parseMmaOperand("A", frags[0]).failed()) 309 return failure(); 310 if (parseMmaOperand("B", frags[1]).failed()) 311 return failure(); 312 if (parseMmaOperand("C", frags[2]).failed()) 313 return failure(); 314 315 if (parser.parseOptionalAttrDict(namedAttributes).failed()) 316 return failure(); 317 318 // Parse the type specification and resolve operands. 
319 SmallVector<Type, 3> operandTypes; 320 if (failed(parser.parseColon())) 321 return failure(); 322 if (failed(parser.parseLParen())) 323 return failure(); 324 if (failed(parser.parseTypeList(operandTypes))) 325 return failure(); 326 if (failed(parser.parseRParen())) 327 if (operandTypes.size() != 3) 328 return parser.emitError( 329 parser.getNameLoc(), 330 "expected one type for each operand segment but got " + 331 Twine(operandTypes.size()) + " types"); 332 for (const auto &iter : llvm::enumerate(operandTypes)) { 333 auto &frag = frags[iter.index()]; 334 frag.regTypes.resize(frag.regs.size(), iter.value()); 335 if (failed(parser.resolveOperands(frag.regs, frag.regTypes, 336 parser.getNameLoc(), result.operands))) 337 return failure(); 338 frag.elemtype = 339 inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2); 340 } 341 342 Type resultType; 343 if (parser.parseArrow() || parser.parseType(resultType)) 344 return failure(); 345 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true); 346 347 std::array<StringRef, 2> names{"multiplicandAPtxType", 348 "multiplicandBPtxType"}; 349 for (unsigned idx = 0; idx < names.size(); idx++) { 350 const auto &frag = frags[idx]; 351 std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]); 352 if (!frag.elemtype.has_value() && !attr.has_value()) { 353 return parser.emitError( 354 parser.getNameLoc(), 355 "attribute " + names[idx] + 356 " is not provided explicitly and cannot be inferred"); 357 } 358 if (!attr.has_value()) 359 result.addAttribute( 360 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); 361 } 362 363 result.addTypes(resultType); 364 if (!namedAttributes.empty()) 365 result.addAttributes(namedAttributes); 366 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(), 367 builder.getDenseI32ArrayAttr({ 368 static_cast<int32_t>(frags[0].regs.size()), 369 static_cast<int32_t>(frags[1].regs.size()), 370 static_cast<int32_t>(frags[2].regs.size()), 371 })); 372 
return success(); 373 } 374 375 LogicalResult MmaOp::verify() { 376 MLIRContext *context = getContext(); 377 auto f16Ty = Float16Type::get(context); 378 auto i32Ty = IntegerType::get(context, 32); 379 auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); 380 auto f32Ty = Float32Type::get(context); 381 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( 382 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); 383 384 auto s32x4StructTy = 385 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); 386 auto f32x8StructTy = 387 LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty)); 388 auto f16x2x2StructTy = 389 LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); 390 auto f32x4StructTy = 391 LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); 392 auto s32x2StructTy = 393 LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); 394 395 std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), 396 getShapeAttr().getK()}; 397 398 // These variables define the set of allowed data types for matrices A, B, C, 399 // and result. 400 using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>; 401 using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>; 402 AllowedShapes allowedShapes; 403 AllowedTypes expectedA; 404 AllowedTypes expectedB; 405 AllowedTypes expectedC; 406 SmallVector<Type> expectedResult; 407 408 // When M = 16, we just need to calculate the number of 8xk tiles, where 409 // k is a factor that depends on the data type. 
410 if (mmaShape[0] == 16) { 411 int64_t kFactor; 412 Type multiplicandFragType; 413 switch (*getMultiplicandAPtxType()) { 414 case MMATypes::tf32: 415 kFactor = 4; 416 multiplicandFragType = i32Ty; 417 expectedResult.push_back(LLVM::LLVMStructType::getLiteral( 418 context, {f32Ty, f32Ty, f32Ty, f32Ty})); 419 break; 420 case MMATypes::f16: 421 case MMATypes::bf16: 422 kFactor = 8; 423 multiplicandFragType = f16x2Ty; 424 expectedResult.push_back(f16x2x2StructTy); 425 expectedResult.push_back(f32x4StructTy); 426 break; 427 case MMATypes::s4: 428 case MMATypes::u4: 429 kFactor = 32; 430 break; 431 case MMATypes::b1: 432 kFactor = 128; 433 break; 434 case MMATypes::s8: 435 case MMATypes::u8: 436 kFactor = 16; 437 break; 438 default: 439 return emitError("invalid shape or multiplicand type: " + 440 stringifyEnum(getMultiplicandAPtxType().value())); 441 } 442 443 if (isIntegerPtxType(getMultiplicandAPtxType().value())) { 444 expectedResult.push_back(s32x4StructTy); 445 expectedC.emplace_back(4, i32Ty); 446 multiplicandFragType = i32Ty; 447 } else { 448 expectedC.emplace_back(2, f16x2Ty); 449 expectedC.emplace_back(4, f32Ty); 450 } 451 452 int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor); 453 int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); 454 expectedA.emplace_back(unitA, multiplicandFragType); 455 expectedB.emplace_back(unitB, multiplicandFragType); 456 allowedShapes.push_back({16, 8, kFactor}); 457 allowedShapes.push_back({16, 8, kFactor * 2}); 458 } 459 460 // In the M=8 case, there is only 1 possible case per data type. 
461 if (mmaShape[0] == 8) { 462 if (*getMultiplicandAPtxType() == MMATypes::f16) { 463 expectedA.emplace_back(2, f16x2Ty); 464 expectedB.emplace_back(2, f16x2Ty); 465 expectedResult.push_back(f16x2x4StructTy); 466 expectedResult.push_back(f32x8StructTy); 467 expectedC.emplace_back(4, f16x2Ty); 468 expectedC.emplace_back(8, f32Ty); 469 allowedShapes.push_back({8, 8, 4}); 470 } 471 if (*getMultiplicandAPtxType() == MMATypes::f64) { 472 Type f64Ty = Float64Type::get(context); 473 expectedA.emplace_back(1, f64Ty); 474 expectedB.emplace_back(1, f64Ty); 475 expectedC.emplace_back(2, f64Ty); 476 // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2)); 477 expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( 478 context, SmallVector<Type>(2, f64Ty))); 479 allowedShapes.push_back({8, 8, 4}); 480 } 481 if (isIntegerPtxType(getMultiplicandAPtxType().value())) { 482 expectedA.push_back({i32Ty}); 483 expectedB.push_back({i32Ty}); 484 expectedC.push_back({i32Ty, i32Ty}); 485 expectedResult.push_back(s32x2StructTy); 486 if (isInt4PtxType(getMultiplicandAPtxType().value())) 487 allowedShapes.push_back({8, 8, 32}); 488 if (isInt8PtxType(getMultiplicandAPtxType().value())) 489 allowedShapes.push_back({8, 8, 16}); 490 if (getMultiplicandAPtxType().value() == MMATypes::b1) 491 allowedShapes.push_back({8, 8, 128}); 492 } 493 } 494 495 std::string errorMessage; 496 llvm::raw_string_ostream errorStream(errorMessage); 497 498 // Check that we matched an existing shape/dtype combination. 499 if (expectedA.empty() || expectedB.empty() || expectedC.empty() || 500 !llvm::is_contained(allowedShapes, mmaShape)) { 501 errorStream << "unimplemented variant for MMA shape <"; 502 llvm::interleaveComma(mmaShape, errorStream); 503 errorStream << ">"; 504 return emitOpError(errorMessage); 505 } 506 507 // Verify the operand types for segments of A, B, and C operands. 
508 std::array<StringRef, 3> operandNames{"A", "B", "C"}; 509 for (const auto &iter : llvm::enumerate( 510 SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) { 511 auto spec = this->getODSOperandIndexAndLength(iter.index()); 512 SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first, 513 operand_type_begin() + spec.first + 514 spec.second); 515 bool match = llvm::is_contained(iter.value(), operandTySeg); 516 517 if (!match) { 518 errorStream << "Could not match types for the " 519 << operandNames[iter.index()] 520 << " operands; expected one of "; 521 for (const auto &x : iter.value()) { 522 errorStream << x.size() << "x" << x[0] << " "; 523 } 524 errorStream << "but got "; 525 llvm::interleaveComma(operandTySeg, errorStream); 526 return emitOpError(errorStream.str()); 527 } 528 } 529 530 // Check the result type 531 if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { 532 return expectedResultType == getResult().getType(); 533 })) { 534 errorStream 535 << "Could not match allowed types for the result; expected one of "; 536 llvm::interleaveComma(expectedResult, errorStream); 537 errorStream << " but got " << getResult().getType(); 538 return emitOpError(errorStream.str()); 539 } 540 541 // Ensure that binary MMA variants have a b1 MMA operation defined. 542 if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) { 543 return emitOpError("op requires " + getB1OpAttrName().strref() + 544 " attribute"); 545 } 546 547 // Ensure int4/int8 MMA variants specify the accum overflow behavior 548 // attribute. 
549 if (isInt4PtxType(*getMultiplicandAPtxType()) || 550 isInt8PtxType(*getMultiplicandAPtxType())) { 551 if (!getIntOverflowBehavior()) 552 return emitOpError("op requires " + 553 getIntOverflowBehaviorAttrName().strref() + 554 " attribute"); 555 } 556 557 return success(); 558 } 559 560 LogicalResult ShflOp::verify() { 561 if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) 562 return success(); 563 auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); 564 auto elementType = (type && type.getBody().size() == 2) 565 ? llvm::dyn_cast<IntegerType>(type.getBody()[1]) 566 : nullptr; 567 if (!elementType || elementType.getWidth() != 1) 568 return emitError("expected return type to be a two-element struct with " 569 "i1 as the second element"); 570 return success(); 571 } 572 573 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, 574 NVVM::MMAFrag frag, int nRow, 575 int nCol, 576 MLIRContext *context) { 577 unsigned numberElements = 0; 578 Type elementType; 579 OpBuilder builder(context); 580 Type f16x2 = VectorType::get(2, builder.getF16Type()); 581 if (type == NVVM::MMATypes::f16) { 582 elementType = f16x2; 583 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) 584 numberElements = 8; 585 else 586 numberElements = 4; 587 } else if (type == NVVM::MMATypes::f32) { 588 elementType = builder.getF32Type(); 589 numberElements = 8; 590 } else if (type == NVVM::MMATypes::tf32) { 591 elementType = builder.getI32Type(); 592 numberElements = 4; 593 } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) { 594 elementType = builder.getI32Type(); 595 int parallelSize = 0; 596 if (frag == NVVM::MMAFrag::a) 597 parallelSize = nRow; 598 if (frag == NVVM::MMAFrag::b) 599 parallelSize = nCol; 600 601 // m == 16 && n == 16 && k == 16 602 if (parallelSize == 16) 603 numberElements = 2; 604 // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16 605 else if (parallelSize == 8) 606 numberElements = 1; 607 else if 
(parallelSize == 32) 608 numberElements = 4; 609 } else if (type == NVVM::MMATypes::s32) { 610 elementType = builder.getI32Type(); 611 numberElements = 8; 612 } 613 assert(numberElements != 0 && elementType != nullptr); 614 return std::make_pair(elementType, numberElements); 615 } 616 617 static std::pair<mlir::Type, unsigned> 618 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, 619 int k, MLIRContext *context) { 620 int nRow, nCol; 621 if (frag == NVVM::MMAFrag::a) { 622 nRow = m; 623 nCol = k; 624 } else if (frag == NVVM::MMAFrag::b) { 625 nRow = k; 626 nCol = n; 627 } else { 628 nRow = m; 629 nCol = n; 630 } 631 assert(nRow && nCol); 632 return inferMMAType(type, frag, nRow, nCol, context); 633 } 634 635 LogicalResult NVVM::WMMALoadOp::verify() { 636 unsigned addressSpace = 637 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 638 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && 639 addressSpace != NVVM::kSharedMemorySpace) 640 return emitOpError("expected source pointer in memory " 641 "space 0, 1, 3"); 642 643 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), 644 getEltype(), getFrag()) == 0) 645 return emitOpError() << "invalid attribute combination"; 646 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( 647 getEltype(), getFrag(), getM(), getN(), getK(), getContext()); 648 Type dstType = LLVM::LLVMStructType::getLiteral( 649 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); 650 if (getType() != dstType) 651 return emitOpError("expected destination type is a structure of ") 652 << typeInfo.second << " elements of type " << typeInfo.first; 653 return success(); 654 } 655 656 LogicalResult NVVM::WMMAStoreOp::verify() { 657 unsigned addressSpace = 658 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 659 if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace && 660 addressSpace != NVVM::kSharedMemorySpace) 661 
return emitOpError("expected operands to be a source pointer in memory " 662 "space 0, 1, 3"); 663 664 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(), 665 getEltype()) == 0) 666 return emitOpError() << "invalid attribute combination"; 667 std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( 668 getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); 669 if (getArgs().size() != typeInfo.second) 670 return emitOpError() << "expected " << typeInfo.second << " data operands"; 671 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) { 672 return operands.getType() != typeInfo.first; 673 })) 674 return emitOpError() << "expected data operands of type " << typeInfo.first; 675 return success(); 676 } 677 678 LogicalResult NVVM::WMMAMmaOp::verify() { 679 if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(), 680 getLayoutB(), getEltypeA(), 681 getEltypeB()) == 0) 682 return emitOpError() << "invalid attribute combination"; 683 std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK( 684 getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext()); 685 std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK( 686 getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext()); 687 std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK( 688 getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext()); 689 SmallVector<Type, 32> arguments; 690 arguments.append(typeInfoA.second, typeInfoA.first); 691 arguments.append(typeInfoB.second, typeInfoB.first); 692 arguments.append(typeInfoC.second, typeInfoC.first); 693 unsigned numArgs = arguments.size(); 694 if (getArgs().size() != numArgs) 695 return emitOpError() << "expected " << numArgs << " arguments"; 696 for (unsigned i = 0; i < numArgs; i++) { 697 if (getArgs()[i].getType() != arguments[i]) 698 return emitOpError() << "expected argument " << i << " to be of type " 699 << arguments[i]; 700 } 701 Type dstType = 
LLVM::LLVMStructType::getLiteral( 702 getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first)); 703 if (getType() != dstType) 704 return emitOpError("expected destination type is a structure of ") 705 << typeInfoC.second << " elements of type " << typeInfoC.first; 706 return success(); 707 } 708 709 LogicalResult NVVM::LdMatrixOp::verify() { 710 unsigned addressSpace = 711 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 712 if (addressSpace != NVVM::kSharedMemorySpace) 713 return emitOpError("expected source pointer in memory space 3"); 714 715 if (getNum() != 1 && getNum() != 2 && getNum() != 4) 716 return emitOpError("expected num attribute to be 1, 2 or 4"); 717 718 Type i32 = IntegerType::get(getContext(), 32); 719 if (getNum() == 1 && getType() != i32) 720 return emitOpError("expected destination type is i32"); 721 if (getNum() == 2 || getNum() == 4) { 722 Type dstType = LLVM::LLVMStructType::getLiteral( 723 getContext(), SmallVector<Type>(getNum(), i32)); 724 if (getType() != dstType) 725 return emitOpError("expected destination type is a structure of ") 726 << getNum() << " elements of type i32"; 727 } 728 return success(); 729 } 730 731 LogicalResult NVVM::StMatrixOp::verify() { 732 unsigned addressSpace = 733 llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); 734 if (addressSpace != NVVM::kSharedMemorySpace) 735 return emitOpError("expected source pointer in memory space 3"); 736 737 int numMatrix = getSources().size(); 738 if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) 739 return emitOpError("expected num attribute to be 1, 2 or 4"); 740 741 return success(); 742 } 743 744 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) { 745 if (typeA == NVVM::WGMMATypes::tf32) 746 return 8; 747 if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16) 748 return 16; 749 if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8) 750 return 32; 751 if (typeA == 
NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2) 752 return 32; 753 if (typeA == NVVM::WGMMATypes::b1) 754 return 256; 755 return failure(); 756 } 757 758 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD, 759 NVVM::WGMMATypes typeA, 760 NVVM::WGMMATypes typeB) { 761 switch (typeA) { 762 case NVVM::WGMMATypes::f16: 763 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 764 typeB == NVVM::WGMMATypes::f16) 765 return success(); 766 break; 767 case NVVM::WGMMATypes::tf32: 768 if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32) 769 return success(); 770 break; 771 case NVVM::WGMMATypes::u8: 772 case NVVM::WGMMATypes::s8: 773 if (typeD == NVVM::WGMMATypes::s32 && 774 (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8)) 775 return success(); 776 break; 777 case NVVM::WGMMATypes::b1: 778 if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1) 779 return success(); 780 break; 781 case NVVM::WGMMATypes::bf16: 782 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 783 typeB == NVVM::WGMMATypes::bf16) 784 return success(); 785 break; 786 case NVVM::WGMMATypes::e4m3: 787 case NVVM::WGMMATypes::e5m2: 788 if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) && 789 (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3)) 790 return success(); 791 break; 792 case WGMMATypes::f32: 793 case WGMMATypes::s32: 794 llvm_unreachable("unsupported input types"); 795 break; 796 } 797 return failure(); 798 } 799 800 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) { 801 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, 802 72, 80, 88, 96, 104, 112, 120, 128, 803 136, 144, 152, 160, 168, 176, 184, 192, 804 200, 208, 216, 224, 232, 240, 248, 256}; 805 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, 806 80, 96, 112, 128, 144, 160, 807 176, 192, 208, 224, 240, 256}; 808 switch (typeA) { 809 case WGMMATypes::f16: 810 case 
WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    // These input element types accept any N from the full `allowedN` list.
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    // Integer and single-bit inputs only accept the shorter `allowedNshort`
    // list of N sizes.
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    // f32/s32 appear only as accumulator types, never as A/B input types.
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

// Verifier for wgmma.mma_async: checks the result struct, the D/A/B type
// combination, the M/N/K shape, the layout/transpose restrictions, and the
// satfinite modifier.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  // The result must be an LLVM struct; all of its elements must share one
  // element type.
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be same type but there is " << t;
  }

  // Accumulator/output type is restricted to f32, f16, or s32.
  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  // Input negation (scale == neg) is not available with an s32 accumulator.
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  // The (D, A, B) triple must be one of the supported combinations.
  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << ", it is not supported.";
  }

  // Check M: the only supported M is 64.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K: each input type admits exactly one K value.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N against the per-input-type allowed list (see isAllowedSizeN).
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
                         << getShape().getN() << ", it is not supported.";
  }

  // Check transpose (only available for f16/bf16): a column-major layout on
  // either operand implies a transpose, which other input types cannot do.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::col)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " requires transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers: f32/s32 accumulators need N/2 struct elements,
  // f16 accumulators need N/4 (matches getPtx's register count below).
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "results " << expectedOutput
                         << ", however output struct has " << outputSize
                         << " elements";
  }
  // Check satfinite (only available for s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can be only used with s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

// Assembles the textual PTX (inline-assembly body) for this wgmma.mma_async
// operation. Operand ordering must stay in sync with getAsmValues below.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() ==
WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // Number of output registers: N/4 for an f16 accumulator, N/2 otherwise
  // (mirrors the struct-size check in the verifier).
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // Emit a predicate that enables accumulation (scale-d), then the wgmma
  // mnemonic with shape and element types.
  // NOTE(review): operand $((expectedOutputRegisters * 2) + 2) is assumed to
  // be the scale-d immediate pushed by getAsmValues; the 2x factor implies
  // each output register occupies two inline-asm operand slots — confirm
  // against the operand list built in getAsmValues.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  // Output register list: {$0, $1, ..., $r-1}.
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly: input operands start after
  // both halves of the output registers, i.e. at index 2 * regCnt.
  regCnt = (regCnt * 2);
  // $regCnt / $regCnt+1 are the A/B matrix descriptors; p is the scale-d
  // predicate computed above.
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  // Non-s32 accumulators additionally take the A/B scale (+1/-1) immediates.
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed (f16/bf16 inputs only).
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  ss.flush();
  return ptx;
}

// Collects the inline-assembly operands in the order getPtx expects:
// results (write), inouts (read-write), then the read operands: descriptor A,
// descriptor B, scale-d, optionally scale-a/scale-b, and for f16/bf16 inputs
// the two layout (transpose) flags.
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  // Scale immediates are only emitted for non-s32 accumulators (the verifier
  // rejects neg scales with s32): neg maps to -1, otherwise +1.
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ?
-1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    // NOTE(review): layout B is passed as 1 - <enum value>, i.e. inverted
    // relative to layout A — presumably matching the PTX transpose-operand
    // convention for the B matrix; confirm against the PTX ISA.
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

// Verifier for fence.proxy: the `space` attribute must be present exactly
// when the proxy kind is async_shared.
LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

// Verifier for setmaxregister: the register count must be a multiple of 8
// and lie in [24, 256].
LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be in between 24 to 256");
  return success();
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  // Register the TableGen-generated operation and attribute lists.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
1042 allowUnknownOperations(); 1043 declarePromisedInterface<NVVMDialect, ConvertToLLVMPatternInterface>(); 1044 declarePromisedInterface<NVVMTargetAttr, gpu::TargetAttrInterface>(); 1045 } 1046 1047 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, 1048 NamedAttribute attr) { 1049 StringAttr attrName = attr.getName(); 1050 // Kernel function attribute should be attached to functions. 1051 if (attrName == NVVMDialect::getKernelFuncAttrName()) { 1052 if (!isa<LLVM::LLVMFuncOp>(op)) { 1053 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() 1054 << "' attribute attached to unexpected op"; 1055 } 1056 } 1057 // If maxntid and reqntid exist, it must be an array with max 3 dim 1058 if (attrName == NVVMDialect::getMaxntidAttrName() || 1059 attrName == NVVMDialect::getReqntidAttrName()) { 1060 auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); 1061 if (!values || values.empty() || values.size() > 3) 1062 return op->emitError() 1063 << "'" << attrName 1064 << "' attribute must be integer array with maximum 3 index"; 1065 } 1066 // If minctasm and maxnreg exist, it must be an integer attribute 1067 if (attrName == NVVMDialect::getMinctasmAttrName() || 1068 attrName == NVVMDialect::getMaxnregAttrName()) { 1069 if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) 1070 return op->emitError() 1071 << "'" << attrName << "' attribute must be integer constant"; 1072 } 1073 1074 return success(); 1075 } 1076 1077 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op, 1078 unsigned regionIndex, 1079 unsigned argIndex, 1080 NamedAttribute argAttr) { 1081 auto funcOp = dyn_cast<FunctionOpInterface>(op); 1082 if (!funcOp) 1083 return success(); 1084 1085 bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName()); 1086 StringAttr attrName = argAttr.getName(); 1087 if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) { 1088 if (!isKernel) { 1089 return op->emitError() 1090 << "'" << attrName 1091 << "' attribute 
must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    // grid_constant only makes sense on byval pointer arguments.
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
// Verifies #nvvm.target: optimization level in [0, 3], non-empty triple and
// chip, and (when present) a `link` array whose elements are all strings.
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"