1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR. It also registers the dialect.
11 //
// The NVVM dialect only contains GPU-specific additions on top of the general
// LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/OperationSupport.h"
30 #include "mlir/IR/Types.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/AsmParser/Parser.h"
34 #include "llvm/IR/Attributes.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/SourceMgr.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <cassert>
41 #include <optional>
42 #include <string>
43 
44 using namespace mlir;
45 using namespace NVVM;
46 
47 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
48 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
49 
50 //===----------------------------------------------------------------------===//
51 // Printing/parsing for NVVM ops
52 //===----------------------------------------------------------------------===//
53 
54 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
55   p << " " << op->getOperands();
56   if (op->getNumResults() > 0)
57     p << " : " << op->getResultTypes();
58 }
59 
60 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
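// A minimal usage sketch of this form (illustrative; the exact mnemonic
// spelling and types here are assumptions, not copied from a test):
//   %ballot = nvvm.vote.ballot.sync %mask, %pred : i32
// where %mask is an i32 lane mask and %pred is an i1 predicate.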
61 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
62   MLIRContext *context = parser.getContext();
63   auto int32Ty = IntegerType::get(context, 32);
64   auto int1Ty = IntegerType::get(context, 1);
65 
66   SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
67   Type type;
68   return failure(parser.parseOperandList(ops) ||
69                  parser.parseOptionalAttrDict(result.attributes) ||
70                  parser.parseColonType(type) ||
71                  parser.addTypeToList(type, result.types) ||
72                  parser.resolveOperands(ops, {int32Ty, int1Ty},
73                                         parser.getNameLoc(), result.operands));
74 }
75 
76 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
77 
78 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
79   if (getCoordinates().empty() || getCoordinates().size() > 5)
    return emitError("expects between 1 and 5 coordinate dimensions");
81 
82   // Check for im2col mode
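  // Worked example (illustrative): a 5-D im2col copy is expected to carry
  // 5 - 2 = 3 im2col offsets, matching the check below.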
83   if (!getIm2colOffsets().empty()) {
84     if (getCoordinates().size() < 3)
85       return emitError(
86           "to use im2col mode, the tensor has to be at least 3-dimensional");
87     if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
88       return emitError(
89           "im2col offsets must be 2 less than number of coordinates");
90   }
91   return success();
92 }
93 
94 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
95   if (getCoordinates().size() > 5)
    return emitError("at most 5 coordinate dimensions are supported");
97   return success();
98 }
99 
100 LogicalResult CpAsyncOp::verify() {
101   if (getModifier() != LoadCacheModifierKind::CG &&
102       getModifier() != LoadCacheModifierKind::CA)
103     return emitError("Only CG and CA cache modifiers are supported.");
104   if (getSize() != 4 && getSize() != 8 && getSize() != 16)
105     return emitError("expected byte size to be either 4, 8 or 16.");
106   if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
108   return success();
109 }
110 
111 // Given the element type of an operand and whether or not it is an accumulator,
112 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
113 // operand's element type.
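// As a rough summary of the cases handled below (not an exhaustive
// specification):
//   f64                  -> f64
//   f16 or vector<2xf16> -> f16
//   f32                  -> f32 as accumulator, tf32 otherwise
//   integer types        -> s32, and only when used as accumulator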
114 std::optional<mlir::NVVM::MMATypes>
115 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
116   auto half2Type =
117       LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
118   if (operandElType.isF64())
119     return NVVM::MMATypes::f64;
120   if (operandElType.isF16() || operandElType == half2Type)
121     return NVVM::MMATypes::f16;
122   if (operandElType.isF32() && isAccumulator)
123     return NVVM::MMATypes::f32;
124   if (operandElType.isF32() && !isAccumulator)
125     return NVVM::MMATypes::tf32;
126   if (llvm::isa<IntegerType>(operandElType)) {
127     if (isAccumulator)
128       return NVVM::MMATypes::s32;
129     return std::nullopt;
130   }
131 
132   if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
133     if (structType.getBody().empty())
134       return std::nullopt;
135     return inferOperandMMAType(structType.getBody()[0], isAccumulator);
136   }
137 
138   return std::nullopt;
139 }
140 
141 static bool isInt4PtxType(MMATypes type) {
142   return (type == MMATypes::u4 || type == MMATypes::s4);
143 }
144 
145 static bool isInt8PtxType(MMATypes type) {
146   return (type == MMATypes::u8 || type == MMATypes::s8);
147 }
148 
149 static bool isIntegerPtxType(MMATypes type) {
150   return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
151          type == MMATypes::s32;
152 }
153 
154 MMATypes MmaOp::accumPtxType() {
155   std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
156       getODSOperands(2).getTypes().front(), /*isAccum=*/true);
157   assert(val.has_value() && "accumulator PTX type should always be inferrable");
158   return val.value();
159 }
160 
161 MMATypes MmaOp::resultPtxType() {
162   std::optional<mlir::NVVM::MMATypes> val =
163       inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
164   assert(val.has_value() && "result PTX type should always be inferrable");
165   return val.value();
166 }
167 
168 void MmaOp::print(OpAsmPrinter &p) {
169   SmallVector<Type, 4> regTypes;
170   struct OperandFragment {
171     StringRef operandName;
172     StringRef ptxTypeAttr;
173     SmallVector<Value, 4> regs;
174     explicit OperandFragment(StringRef name, StringRef ptxTypeName)
175         : operandName(name), ptxTypeAttr(ptxTypeName) {}
176   };
177 
178   std::array<OperandFragment, 3> frags{
179       OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
180       OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
181       OperandFragment("C", "")};
182   SmallVector<StringRef, 4> ignoreAttrNames{
183       mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
184 
185   for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
186     auto &frag = frags[fragIdx];
187     auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
188     for (auto operandIdx = varOperandSpec.first;
189          operandIdx < varOperandSpec.first + varOperandSpec.second;
190          operandIdx++) {
191       frag.regs.push_back(this->getOperand(operandIdx));
192       if (operandIdx == 0) {
193         regTypes.push_back(this->getOperand(operandIdx).getType());
194       }
195     }
196     std::optional<MMATypes> inferredType =
197         inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
198     if (inferredType)
199       ignoreAttrNames.push_back(frag.ptxTypeAttr);
200   }
201 
202   auto printMmaOperand = [&](const OperandFragment &frag) -> void {
203     p << " " << frag.operandName;
204     p << "[";
205     p.printOperands(frag.regs);
206     p << "] ";
207   };
208 
209   for (const auto &frag : frags) {
210     printMmaOperand(frag);
211   }
212 
213   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
214 
215   // Print the types of the operands and result.
  p << " : (";
217   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
218                                              frags[1].regs[0].getType(),
219                                              frags[2].regs[0].getType()},
220                         p);
221   p << ")";
222   p.printArrowTypeList(TypeRange{this->getRes().getType()});
223 }
224 
225 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
226                   ValueRange operandA, ValueRange operandB, ValueRange operandC,
227                   ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
228                   std::optional<MMAIntOverflow> intOverflow,
229                   std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
230                   std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
231 
232   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
233   MLIRContext *ctx = builder.getContext();
234   result.addAttribute(
235       "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
236 
237   result.addOperands(operandA);
238   result.addOperands(operandB);
239   result.addOperands(operandC);
240 
241   if (multiplicandPtxTypes) {
242     result.addAttribute("multiplicandAPtxType",
243                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
244     result.addAttribute("multiplicandBPtxType",
245                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
246   } else {
247     if (auto res = inferOperandMMAType(operandA[0].getType(), false))
248       result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
249     if (auto res = inferOperandMMAType(operandB[0].getType(), false))
250       result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
251   }
252 
253   if (multiplicandLayouts) {
254     result.addAttribute("layoutA",
255                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
256     result.addAttribute("layoutB",
257                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
258   } else {
259     result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
260     result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
261   }
262 
263   if (intOverflow.has_value())
264     result.addAttribute("intOverflowBehavior",
265                         MMAIntOverflowAttr::get(ctx, *intOverflow));
266   if (b1Op.has_value())
267     result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
268 
269   result.addTypes(resultType);
270   result.addAttribute(
271       MmaOp::getOperandSegmentSizeAttr(),
272       builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
273                                     static_cast<int32_t>(operandB.size()),
274                                     static_cast<int32_t>(operandC.size())}));
275 }
276 
277 // <operation> :=
278 //   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
279 //   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
280 //     `->` type($res)
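// A hedged, illustrative example of this form (operand counts depend on the
// chosen shape, and the attribute syntax shown here is an assumption rather
// than a copy of an existing test):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>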
281 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
282   struct OperandFragment {
283     std::optional<MMATypes> elemtype;
284     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
285     SmallVector<Type> regTypes;
286   };
287 
288   Builder &builder = parser.getBuilder();
289   std::array<OperandFragment, 4> frags;
290 
291   NamedAttrList namedAttributes;
292 
293   // A helper to parse the operand segments.
294   auto parseMmaOperand = [&](StringRef operandName,
295                              OperandFragment &frag) -> LogicalResult {
296     if (parser.parseKeyword(operandName).failed())
297       return failure();
298     if (parser
299             .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
300             .failed())
301       return failure();
302     return success();
303   };
304 
305   // Parse the operand segments.
306   if (parseMmaOperand("A", frags[0]).failed())
307     return failure();
308   if (parseMmaOperand("B", frags[1]).failed())
309     return failure();
310   if (parseMmaOperand("C", frags[2]).failed())
311     return failure();
312 
313   if (parser.parseOptionalAttrDict(namedAttributes).failed())
314     return failure();
315 
316   // Parse the type specification and resolve operands.
317   SmallVector<Type, 3> operandTypes;
318   if (failed(parser.parseColon()))
319     return failure();
320   if (failed(parser.parseLParen()))
321     return failure();
322   if (failed(parser.parseTypeList(operandTypes)))
323     return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
330   for (const auto &iter : llvm::enumerate(operandTypes)) {
331     auto &frag = frags[iter.index()];
332     frag.regTypes.resize(frag.regs.size(), iter.value());
333     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
334                                       parser.getNameLoc(), result.operands)))
335       return failure();
336     frag.elemtype =
337         inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
338   }
339 
340   Type resultType;
341   if (parser.parseArrow() || parser.parseType(resultType))
342     return failure();
343   frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
344 
345   std::array<StringRef, 2> names{"multiplicandAPtxType",
346                                  "multiplicandBPtxType"};
347   for (unsigned idx = 0; idx < names.size(); idx++) {
348     const auto &frag = frags[idx];
349     std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
350     if (!frag.elemtype.has_value() && !attr.has_value()) {
351       return parser.emitError(
352           parser.getNameLoc(),
353           "attribute " + names[idx] +
354               " is not provided explicitly and cannot be inferred");
355     }
356     if (!attr.has_value())
357       result.addAttribute(
358           names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
359   }
360 
361   result.addTypes(resultType);
362   if (!namedAttributes.empty())
363     result.addAttributes(namedAttributes);
364   result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
365                       builder.getDenseI32ArrayAttr({
366                           static_cast<int32_t>(frags[0].regs.size()),
367                           static_cast<int32_t>(frags[1].regs.size()),
368                           static_cast<int32_t>(frags[2].regs.size()),
369                       }));
370   return success();
371 }
372 
373 LogicalResult MmaOp::verify() {
374   MLIRContext *context = getContext();
375   auto f16Ty = Float16Type::get(context);
376   auto i32Ty = IntegerType::get(context, 32);
377   auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
378   auto f32Ty = Float32Type::get(context);
379   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
380       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
381 
382   auto s32x4StructTy =
383       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
384   auto f32x8StructTy =
385       LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
386   auto f16x2x2StructTy =
387       LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
388   auto f32x4StructTy =
389       LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
390   auto s32x2StructTy =
391       LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
392 
393   std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
394                                   getShapeAttr().getK()};
395 
396   // These variables define the set of allowed data types for matrices A, B, C,
397   // and result.
398   using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
399   using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
400   AllowedShapes allowedShapes;
401   AllowedTypes expectedA;
402   AllowedTypes expectedB;
403   AllowedTypes expectedC;
404   SmallVector<Type> expectedResult;
405 
406   // When M = 16, we just need to calculate the number of 8xk tiles, where
407   // k is a factor that depends on the data type.
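  // Worked example (derived from the arithmetic below): for f16 multiplicands
  // kFactor is 8, so an m16n8k16 MMA expects (16/8) * (16/8) = 4 registers of
  // vector<2xf16> for A and (8/8) * (16/8) = 2 such registers for B.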
408   if (mmaShape[0] == 16) {
409     int64_t kFactor;
410     Type multiplicandFragType;
411     switch (*getMultiplicandAPtxType()) {
412     case MMATypes::tf32:
413       kFactor = 4;
414       multiplicandFragType = i32Ty;
415       expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
416           context, {f32Ty, f32Ty, f32Ty, f32Ty}));
417       break;
418     case MMATypes::f16:
419     case MMATypes::bf16:
420       kFactor = 8;
421       multiplicandFragType = f16x2Ty;
422       expectedResult.push_back(f16x2x2StructTy);
423       expectedResult.push_back(f32x4StructTy);
424       break;
425     case MMATypes::s4:
426     case MMATypes::u4:
427       kFactor = 32;
428       break;
429     case MMATypes::b1:
430       kFactor = 128;
431       break;
432     case MMATypes::s8:
433     case MMATypes::u8:
434       kFactor = 16;
435       break;
436     default:
437       return emitError("invalid shape or multiplicand type: " +
438                        stringifyEnum(getMultiplicandAPtxType().value()));
439     }
440 
441     if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
442       expectedResult.push_back(s32x4StructTy);
443       expectedC.emplace_back(4, i32Ty);
444       multiplicandFragType = i32Ty;
445     } else {
446       expectedC.emplace_back(2, f16x2Ty);
447       expectedC.emplace_back(4, f32Ty);
448     }
449 
450     int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
451     int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
452     expectedA.emplace_back(unitA, multiplicandFragType);
453     expectedB.emplace_back(unitB, multiplicandFragType);
454     allowedShapes.push_back({16, 8, kFactor});
455     allowedShapes.push_back({16, 8, kFactor * 2});
456   }
457 
458   // In the M=8 case, there is only 1 possible case per data type.
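  // For instance (derived from the checks below): an f64 m8n8k4 MMA expects a
  // single f64 register for each of A and B, two f64 registers for C, and a
  // struct of two f64 values as the result.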
459   if (mmaShape[0] == 8) {
460     if (*getMultiplicandAPtxType() == MMATypes::f16) {
461       expectedA.emplace_back(2, f16x2Ty);
462       expectedB.emplace_back(2, f16x2Ty);
463       expectedResult.push_back(f16x2x4StructTy);
464       expectedResult.push_back(f32x8StructTy);
465       expectedC.emplace_back(4, f16x2Ty);
466       expectedC.emplace_back(8, f32Ty);
467       allowedShapes.push_back({8, 8, 4});
468     }
469     if (*getMultiplicandAPtxType() == MMATypes::f64) {
470       Type f64Ty = Float64Type::get(context);
471       expectedA.emplace_back(1, f64Ty);
472       expectedB.emplace_back(1, f64Ty);
473       expectedC.emplace_back(2, f64Ty);
474       // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
475       expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
476           context, SmallVector<Type>(2, f64Ty)));
477       allowedShapes.push_back({8, 8, 4});
478     }
479     if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
480       expectedA.push_back({i32Ty});
481       expectedB.push_back({i32Ty});
482       expectedC.push_back({i32Ty, i32Ty});
483       expectedResult.push_back(s32x2StructTy);
484       if (isInt4PtxType(getMultiplicandAPtxType().value()))
485         allowedShapes.push_back({8, 8, 32});
486       if (isInt8PtxType(getMultiplicandAPtxType().value()))
487         allowedShapes.push_back({8, 8, 16});
488       if (getMultiplicandAPtxType().value() == MMATypes::b1)
489         allowedShapes.push_back({8, 8, 128});
490     }
491   }
492 
493   std::string errorMessage;
494   llvm::raw_string_ostream errorStream(errorMessage);
495 
496   // Check that we matched an existing shape/dtype combination.
497   if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
498       !llvm::is_contained(allowedShapes, mmaShape)) {
499     errorStream << "unimplemented variant for MMA shape <";
500     llvm::interleaveComma(mmaShape, errorStream);
501     errorStream << ">";
502     return emitOpError(errorMessage);
503   }
504 
505   // Verify the operand types for segments of A, B, and C operands.
506   std::array<StringRef, 3> operandNames{"A", "B", "C"};
507   for (const auto &iter : llvm::enumerate(
508            SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
509     auto spec = this->getODSOperandIndexAndLength(iter.index());
510     SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
511                                       operand_type_begin() + spec.first +
512                                           spec.second);
513     bool match = llvm::is_contained(iter.value(), operandTySeg);
514 
515     if (!match) {
516       errorStream << "Could not match types for the "
517                   << operandNames[iter.index()]
518                   << " operands; expected one of ";
519       for (const auto &x : iter.value()) {
520         errorStream << x.size() << "x" << x[0] << " ";
521       }
522       errorStream << "but got ";
523       llvm::interleaveComma(operandTySeg, errorStream);
524       return emitOpError(errorStream.str());
525     }
526   }
527 
528   // Check the result type
529   if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
530         return expectedResultType == getResult().getType();
531       })) {
532     errorStream
533         << "Could not match allowed types for the result; expected one of ";
534     llvm::interleaveComma(expectedResult, errorStream);
535     errorStream << " but got " << getResult().getType();
536     return emitOpError(errorStream.str());
537   }
538 
539   // Ensure that binary MMA variants have a b1 MMA operation defined.
540   if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
541     return emitOpError("op requires " + getB1OpAttrName().strref() +
542                        " attribute");
543   }
544 
545   // Ensure int4/int8 MMA variants specify the accum overflow behavior
546   // attribute.
547   if (isInt4PtxType(*getMultiplicandAPtxType()) ||
548       isInt8PtxType(*getMultiplicandAPtxType())) {
549     if (!getIntOverflowBehavior())
550       return emitOpError("op requires " +
551                          getIntOverflowBehaviorAttrName().strref() +
552                          " attribute");
553   }
554 
555   return success();
556 }
557 
558 LogicalResult ShflOp::verify() {
559   if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
560     return success();
561   auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
562   auto elementType = (type && type.getBody().size() == 2)
563                          ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
564                          : nullptr;
565   if (!elementType || elementType.getWidth() != 1)
566     return emitError("expected return type to be a two-element struct with "
567                      "i1 as the second element");
568   return success();
569 }
570 
571 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
572                                                    NVVM::MMAFrag frag, int nRow,
573                                                    int nCol,
574                                                    MLIRContext *context) {
575   unsigned numberElements = 0;
576   Type elementType;
577   OpBuilder builder(context);
578   Type f16x2 = VectorType::get(2, builder.getF16Type());
579   if (type == NVVM::MMATypes::f16) {
580     elementType = f16x2;
581     if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
582       numberElements = 8;
583     else
584       numberElements = 4;
585   } else if (type == NVVM::MMATypes::f32) {
586     elementType = builder.getF32Type();
587     numberElements = 8;
588   } else if (type == NVVM::MMATypes::tf32) {
589     elementType = builder.getI32Type();
590     numberElements = 4;
591   } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
592     elementType = builder.getI32Type();
593     int parallelSize = 0;
594     if (frag == NVVM::MMAFrag::a)
595       parallelSize = nRow;
596     if (frag == NVVM::MMAFrag::b)
597       parallelSize = nCol;
598 
599     // m == 16 && n == 16 && k == 16
600     if (parallelSize == 16)
601       numberElements = 2;
602     // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
603     else if (parallelSize == 8)
604       numberElements = 1;
605     else if (parallelSize == 32)
606       numberElements = 4;
607   } else if (type == NVVM::MMATypes::s32) {
608     elementType = builder.getI32Type();
609     numberElements = 8;
610   }
611   assert(numberElements != 0 && elementType != nullptr);
612   return std::make_pair(elementType, numberElements);
613 }
614 
615 static std::pair<mlir::Type, unsigned>
616 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
617                     int k, MLIRContext *context) {
618   int nRow, nCol;
619   if (frag == NVVM::MMAFrag::a) {
620     nRow = m;
621     nCol = k;
622   } else if (frag == NVVM::MMAFrag::b) {
623     nRow = k;
624     nCol = n;
625   } else {
626     nRow = m;
627     nCol = n;
628   }
629   assert(nRow && nCol);
630   return inferMMAType(type, frag, nRow, nCol, context);
631 }
632 
633 LogicalResult NVVM::WMMALoadOp::verify() {
634   unsigned addressSpace =
635       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
636   if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
637       addressSpace != NVVM::kSharedMemorySpace)
638     return emitOpError("expected source pointer in memory "
639                        "space 0, 1, 3");
640 
641   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
642                                        getEltype(), getFrag()) == 0)
643     return emitOpError() << "invalid attribute combination";
644   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
645       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
646   Type dstType = LLVM::LLVMStructType::getLiteral(
647       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
648   if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
650            << typeInfo.second << " elements of type " << typeInfo.first;
651   return success();
652 }
653 
654 LogicalResult NVVM::WMMAStoreOp::verify() {
655   unsigned addressSpace =
656       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
657   if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
658       addressSpace != NVVM::kSharedMemorySpace)
659     return emitOpError("expected operands to be a source pointer in memory "
660                        "space 0, 1, 3");
661 
662   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
663                                         getEltype()) == 0)
664     return emitOpError() << "invalid attribute combination";
665   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
666       getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
667   if (getArgs().size() != typeInfo.second)
668     return emitOpError() << "expected " << typeInfo.second << " data operands";
669   if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
670         return operands.getType() != typeInfo.first;
671       }))
672     return emitOpError() << "expected data operands of type " << typeInfo.first;
673   return success();
674 }
675 
676 LogicalResult NVVM::WMMAMmaOp::verify() {
677   if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
678                                       getLayoutB(), getEltypeA(),
679                                       getEltypeB()) == 0)
680     return emitOpError() << "invalid attribute combination";
681   std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
682       getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
683   std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
684       getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
685   std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
686       getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
687   SmallVector<Type, 32> arguments;
688   arguments.append(typeInfoA.second, typeInfoA.first);
689   arguments.append(typeInfoB.second, typeInfoB.first);
690   arguments.append(typeInfoC.second, typeInfoC.first);
691   unsigned numArgs = arguments.size();
692   if (getArgs().size() != numArgs)
693     return emitOpError() << "expected " << numArgs << " arguments";
694   for (unsigned i = 0; i < numArgs; i++) {
695     if (getArgs()[i].getType() != arguments[i])
696       return emitOpError() << "expected argument " << i << " to be of type "
697                            << arguments[i];
698   }
699   Type dstType = LLVM::LLVMStructType::getLiteral(
700       getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
701   if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
703            << typeInfoC.second << " elements of type " << typeInfoC.first;
704   return success();
705 }
706 
707 LogicalResult NVVM::LdMatrixOp::verify() {
708   unsigned addressSpace =
709       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
710   if (addressSpace != NVVM::kSharedMemorySpace)
711     return emitOpError("expected source pointer in memory space 3");
712 
713   if (getNum() != 1 && getNum() != 2 && getNum() != 4)
714     return emitOpError("expected num attribute to be 1, 2 or 4");
715 
716   Type i32 = IntegerType::get(getContext(), 32);
717   if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
719   if (getNum() == 2 || getNum() == 4) {
720     Type dstType = LLVM::LLVMStructType::getLiteral(
721         getContext(), SmallVector<Type>(getNum(), i32));
722     if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
724              << getNum() << " elements of type i32";
725   }
726   return success();
727 }
728 
729 LogicalResult NVVM::StMatrixOp::verify() {
730   unsigned addressSpace =
731       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
732   if (addressSpace != NVVM::kSharedMemorySpace)
733     return emitOpError("expected source pointer in memory space 3");
734 
735   int numMatrix = getSources().size();
736   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
737     return emitOpError("expected num attribute to be 1, 2 or 4");
738 
739   return success();
740 }
741 
742 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
743   if (typeA == NVVM::WGMMATypes::tf32)
744     return 8;
745   if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
746     return 16;
747   if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
748     return 32;
749   if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
750     return 32;
751   if (typeA == NVVM::WGMMATypes::b1)
752     return 256;
753   return failure();
754 }
755 
756 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
757                                      NVVM::WGMMATypes typeA,
758                                      NVVM::WGMMATypes typeB) {
759   switch (typeA) {
760   case NVVM::WGMMATypes::f16:
761     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
762         typeB == NVVM::WGMMATypes::f16)
763       return success();
764     break;
765   case NVVM::WGMMATypes::tf32:
766     if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
767       return success();
768     break;
769   case NVVM::WGMMATypes::u8:
770   case NVVM::WGMMATypes::s8:
771     if (typeD == NVVM::WGMMATypes::s32 &&
772         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
773       return success();
774     break;
775   case NVVM::WGMMATypes::b1:
776     if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
777       return success();
778     break;
779   case NVVM::WGMMATypes::bf16:
780     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
781         typeB == NVVM::WGMMATypes::bf16)
782       return success();
783     break;
784   case NVVM::WGMMATypes::e4m3:
785   case NVVM::WGMMATypes::e5m2:
786     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
787         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
788       return success();
789     break;
790   case WGMMATypes::f32:
791   case WGMMATypes::s32:
792     llvm_unreachable("unsupported input types");
793     break;
794   }
795   return failure();
796 }
797 
798 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
799   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
800                                72,  80,  88,  96,  104, 112, 120, 128,
801                                136, 144, 152, 160, 168, 176, 184, 192,
802                                200, 208, 216, 224, 232, 240, 248, 256};
803   SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
804                                     80,  96,  112, 128, 144, 160,
805                                     176, 192, 208, 224, 240, 256};
806   switch (typeA) {
807   case WGMMATypes::f16:
808   case WGMMATypes::tf32:
809   case WGMMATypes::bf16:
810   case WGMMATypes::e4m3:
811   case WGMMATypes::e5m2:
812     if (llvm::is_contained(allowedN, sizeN))
813       return success();
814     break;
815   case WGMMATypes::u8:
816   case WGMMATypes::s8:
817   case WGMMATypes::b1:
818     if (llvm::is_contained(allowedNshort, sizeN))
819       return success();
820     break;
821   case WGMMATypes::f32:
822   case WGMMATypes::s32:
823     llvm_unreachable("unsupported input types");
824     break;
825   }
826   return failure();
827 }
828 
829 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
830   Value outValue = getResults();
831   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
832   if (!stype)
833     return emitOpError() << "expected results to be struct";
834   int outputSize = stype.getBody().size();
835   WGMMATypes typeD = getTypeD();
836   WGMMATypes typeA = getTypeA();
837   WGMMATypes typeB = getTypeB();
838 
839   for (Type t : stype.getBody()) {
840     if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be the same type but there is "
             << t;
843   }
844 
845   if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
846       typeD != WGMMATypes::s32) {
847     return emitOpError() << "does not support the given output type "
848                          << NVVM::stringifyWGMMATypes(typeD);
849   }
850   if (typeD == WGMMATypes::s32 &&
851       (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
852     return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
853   }
854 
855   if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
856     return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
857                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
858                          << NVVM::stringifyWGMMATypes(typeB)
859                          << ", it is not supported.";
860   }
861 
862   // Check M
863   if (getShape().getM() != 64)
864     return emitOpError() << "shape 'm' must be 64";
865 
866   // Check K
867   FailureOr<int> allowedK = getAllowedSizeK(typeA);
868   if (failed(allowedK) || allowedK.value() != getShape().getK())
869     return emitOpError() << "shape 'k' must be " << allowedK.value()
870                          << " for input type "
871                          << NVVM::stringifyWGMMATypes(typeA);
872 
873   // Check N
874   if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
875     return emitOpError() << "has input type "
876                          << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
877                          << getShape().getN() << ", it is not supported.";
878   }
879 
880   // Check transpose (only available for f16/bf16)
881   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
882       (getLayoutA() == mlir::NVVM::MMALayout::col ||
883        getLayoutB() == mlir::NVVM::MMALayout::col)) {
884     return emitOpError()
885            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
886            << " and layout_b = " << stringifyMMALayout(getLayoutB())
887            << " for input types " << stringifyWGMMATypes(typeA) << " and "
888            << stringifyWGMMATypes(typeB)
889            << " requires transpose. However, this is only supported for: "
890            << stringifyMMATypes(MMATypes::f16) << " and "
891            << stringifyMMATypes(MMATypes::bf16);
892   }
893 
894   // Check result registers
895   int expectedOutput = 0;
896   if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
897     expectedOutput = getShape().getN() / 2;
898   if (typeD == WGMMATypes::f16)
899     expectedOutput = getShape().getN() / 4;
900   if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result elements, however the output struct has "
                         << outputSize << " elements";
904   }
905   // Check satfinite (only available for s32 accumulator)
906   if (typeD != WGMMATypes::s32 &&
907       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
908           NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << "`satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
913   }
914 
915   return success();
916 }
917 
918 std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
919 
920   int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
921   bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
922 
923   StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
924 
925   int expectedOutputRegisters = 0;
926   if (getTypeD() == WGMMATypes::f16)
927     expectedOutputRegisters = getShape().getN() / 4;
928   else
929     expectedOutputRegisters = getShape().getN() / 2;
930 
931   std::string ptx;
932   llvm::raw_string_ostream ss(ptx);
933 
934   ss << "{\n"
935         ".reg .pred p;\n"
936         "setp.ne.b32 p, $"
937      << ((expectedOutputRegisters * 2) + 2)
938      << ", 0;\n"
939         "wgmma.mma_async.sync.aligned.m"
940      << m << "n" << n << "k" << k << "." << outputTypeName << "."
941      << stringifyWGMMATypes(getTypeA()) << "."
942      << stringifyWGMMATypes(getTypeB());
943   if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
944       NVVM::MMAIntOverflow::satfinite)
945     ss << ".satfinite";
946   ss << " {";
947   int regCnt = 0;
948   for (; regCnt < expectedOutputRegisters; ++regCnt) {
949     ss << "$" << regCnt;
950     if (regCnt != expectedOutputRegisters - 1)
951       ss << ", ";
952   }
953 
954   ss << "},";
955   // Need to map read/write registers correctly.
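  // Sketch of the operand numbering assumed here: the first
  // expectedOutputRegisters positions are the write-only results, the next
  // expectedOutputRegisters are the read-write accumulator inputs, so the
  // descriptors start at 2 * expectedOutputRegisters. For example, with an
  // f32 accumulator and n = 64 there are 32 output registers: $0-$31 are the
  // results, $32-$63 the accumulator inputs, $64/$65 the descriptors, $66 the
  // scale-d operand (consumed by the setp above), followed by the optional
  // scale and layout operands.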
956   regCnt = (regCnt * 2);
  ss << " $" << regCnt << ", $" << (regCnt + 1) << ", p";
958   if (getTypeD() != WGMMATypes::s32) {
959     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
960   }
961   // Don't add transpose parameters unless needed.
962   if (isF16) {
963     ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
964   }
965   ss << ";\n"
966      << "}\n";
967   ss.flush();
968   return ptx;
969 }
970 
971 void NVVM::WgmmaMmaAsyncOp::getAsmValues(
972     RewriterBase &rewriter,
973     llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
974         &asmValues) {
975   bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
976   if (getResults())
977     asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
978   if (getInouts())
979     asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
980   asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
981   asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
982   asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
983                        mlir::NVVM::PTXRegisterMod::Read});
984   if (getTypeD() != WGMMATypes::s32) {
985     asmValues.push_back(
986         {makeConstantI32(rewriter,
987                          getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
988          mlir::NVVM::PTXRegisterMod::Read});
989     asmValues.push_back(
990         {makeConstantI32(rewriter,
991                          getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
992          mlir::NVVM::PTXRegisterMod::Read});
993   }
994   if (isF16) {
995     asmValues.push_back(
996         {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
997          mlir::NVVM::PTXRegisterMod::Read});
998     asmValues.push_back(
999         {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
1000          mlir::NVVM::PTXRegisterMod::Read});
1001   }
1002 }
1003 LogicalResult NVVM::FenceProxyOp::verify() {
1004   if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1005     return emitOpError() << "async_shared fence requires space attribute";
1006   }
1007   if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1008     return emitOpError() << "only async_shared fence can have space attribute";
1009   }
1010   return success();
1011 }
1012 
1013 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1014   if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
1018   return success();
1019 }
1020 
1021 LogicalResult NVVM::BarrierOp::verify() {
1022   if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing; it should be set between 0 and 15");
1025   return success();
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // NVVMDialect initialization, type parsing, and registration.
1030 //===----------------------------------------------------------------------===//
1031 
1032 // TODO: This should be the llvm.nvvm dialect once this is supported.
1033 void NVVMDialect::initialize() {
1034   addOperations<
1035 #define GET_OP_LIST
1036 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1037       >();
1038   addAttributes<
1039 #define GET_ATTRDEF_LIST
1040 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1041       >();
1042 
1043   // Support unknown operations because not all NVVM operations are
1044   // registered.
1045   allowUnknownOperations();
1046   declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
1047   declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
1048 }
1049 
1050 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1051                                                     NamedAttribute attr) {
1052   StringAttr attrName = attr.getName();
1053   // Kernel function attribute should be attached to functions.
1054   if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1055     if (!isa<LLVM::LLVMFuncOp>(op)) {
1056       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1057                              << "' attribute attached to unexpected op";
1058     }
1059   }
  // If maxntid or reqntid is present, it must be an integer array with at
  // most 3 elements.
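  // Illustrative usage (the attribute spelling and DenseI32ArrayAttr syntax
  // below are assumptions, not taken from a test):
  //   llvm.func @kernel() attributes {nvvm.kernel,
  //                                   nvvm.maxntid = array<i32: 128, 1, 1>}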
1061   if (attrName == NVVMDialect::getMaxntidAttrName() ||
1062       attrName == NVVMDialect::getReqntidAttrName()) {
1063     auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1064     if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array of at most 3 elements";
1068   }
  // If minctasm or maxnreg is present, it must be an integer attribute.
1070   if (attrName == NVVMDialect::getMinctasmAttrName() ||
1071       attrName == NVVMDialect::getMaxnregAttrName()) {
1072     if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
1075   }
1076 
1077   return success();
1078 }
1079 
1080 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1081                                                     unsigned regionIndex,
1082                                                     unsigned argIndex,
1083                                                     NamedAttribute argAttr) {
1084   auto funcOp = dyn_cast<FunctionOpInterface>(op);
1085   if (!funcOp)
1086     return success();
1087 
1088   bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1089   StringAttr attrName = argAttr.getName();
1090   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1091     if (!isKernel) {
1092       return op->emitError()
1093              << "'" << attrName
1094              << "' attribute must be present only on kernel arguments";
1095     }
1096     if (!isa<UnitAttr>(argAttr.getValue()))
1097       return op->emitError() << "'" << attrName << "' must be a unit attribute";
1098     if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1099       return op->emitError()
1100              << "'" << attrName
1101              << "' attribute requires the argument to also have attribute '"
1102              << LLVM::LLVMDialect::getByValAttrName() << "'";
1103     }
1104   }
1105 
1106   return success();
1107 }
1108 
1109 //===----------------------------------------------------------------------===//
1110 // NVVM target attribute.
1111 //===----------------------------------------------------------------------===//
1112 LogicalResult
1113 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1114                        int optLevel, StringRef triple, StringRef chip,
1115                        StringRef features, DictionaryAttr flags,
1116                        ArrayAttr files) {
1117   if (optLevel < 0 || optLevel > 3) {
1118     emitError() << "The optimization level must be a number between 0 and 3.";
1119     return failure();
1120   }
1121   if (triple.empty()) {
1122     emitError() << "The target triple cannot be empty.";
1123     return failure();
1124   }
1125   if (chip.empty()) {
1126     emitError() << "The target chip cannot be empty.";
1127     return failure();
1128   }
1129   if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1130         return attr && mlir::isa<StringAttr>(attr);
1131       })) {
1132     emitError() << "All the elements in the `link` array must be strings.";
1133     return failure();
1134   }
1135   return success();
1136 }
1137 
1138 #define GET_OP_CLASSES
1139 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1140 
1141 #define GET_ATTRDEF_CLASSES
1142 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1143