xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (revision 346c4a88afedcef3da40f68c83f0a5b3e0ac61ea)
1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/OperationSupport.h"
30 #include "mlir/IR/Types.h"
31 #include "mlir/Support/LogicalResult.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/AsmParser/Parser.h"
35 #include "llvm/IR/Attributes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/SourceMgr.h"
40 #include "llvm/Support/raw_ostream.h"
41 #include <cassert>
42 #include <optional>
43 #include <string>
44 
45 using namespace mlir;
46 using namespace NVVM;
47 
48 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
49 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Printing/parsing for NVVM ops
53 //===----------------------------------------------------------------------===//
54 
55 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
56   p << " " << op->getOperands();
57   if (op->getNumResults() > 0)
58     p << " : " << op->getResultTypes();
59 }
60 
61 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
62 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
63   MLIRContext *context = parser.getContext();
64   auto int32Ty = IntegerType::get(context, 32);
65   auto int1Ty = IntegerType::get(context, 1);
66 
67   SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
68   Type type;
69   return failure(parser.parseOperandList(ops) ||
70                  parser.parseOptionalAttrDict(result.attributes) ||
71                  parser.parseColonType(type) ||
72                  parser.addTypeToList(type, result.types) ||
73                  parser.resolveOperands(ops, {int32Ty, int1Ty},
74                                         parser.getNameLoc(), result.operands));
75 }
76 
77 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
78 
79 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
80   if (getCoordinates().empty() || getCoordinates().size() > 5)
81     return emitError("expects coordinates between 1 to 5 dimension");
82 
83   // Check for im2col mode
84   if (!getIm2colOffsets().empty()) {
85     if (getCoordinates().size() < 3)
86       return emitError(
87           "to use im2col mode, the tensor has to be at least 3-dimensional");
88     if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
89       return emitError(
90           "im2col offsets must be 2 less than number of coordinates");
91   }
92   return success();
93 }
94 
95 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
96   if (getCoordinates().size() > 5)
97     return emitError("Maximum 5 coordinates and dimension is supported.");
98   return success();
99 }
100 
101 LogicalResult CpAsyncOp::verify() {
102   if (getModifier() != LoadCacheModifierKind::CG &&
103       getModifier() != LoadCacheModifierKind::CA)
104     return emitError("Only CG and CA cache modifiers are supported.");
105   if (getSize() != 4 && getSize() != 8 && getSize() != 16)
106     return emitError("expected byte size to be either 4, 8 or 16.");
107   if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
108     return emitError("CG cache modifier is only support for 16 bytes copy.");
109   return success();
110 }
111 
112 // Given the element type of an operand and whether or not it is an accumulator,
113 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
114 // operand's element type.
115 std::optional<mlir::NVVM::MMATypes>
116 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
117   auto half2Type =
118       LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
119   if (operandElType.isF64())
120     return NVVM::MMATypes::f64;
121   if (operandElType.isF16() || operandElType == half2Type)
122     return NVVM::MMATypes::f16;
123   if (operandElType.isF32() && isAccumulator)
124     return NVVM::MMATypes::f32;
125   if (operandElType.isF32() && !isAccumulator)
126     return NVVM::MMATypes::tf32;
127   if (llvm::isa<IntegerType>(operandElType)) {
128     if (isAccumulator)
129       return NVVM::MMATypes::s32;
130     return std::nullopt;
131   }
132 
133   if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
134     if (structType.getBody().empty())
135       return std::nullopt;
136     return inferOperandMMAType(structType.getBody()[0], isAccumulator);
137   }
138 
139   return std::nullopt;
140 }
141 
142 static bool isInt4PtxType(MMATypes type) {
143   return (type == MMATypes::u4 || type == MMATypes::s4);
144 }
145 
146 static bool isInt8PtxType(MMATypes type) {
147   return (type == MMATypes::u8 || type == MMATypes::s8);
148 }
149 
150 static bool isIntegerPtxType(MMATypes type) {
151   return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
152          type == MMATypes::s32;
153 }
154 
155 MMATypes MmaOp::accumPtxType() {
156   std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
157       getODSOperands(2).getTypes().front(), /*isAccum=*/true);
158   assert(val.has_value() && "accumulator PTX type should always be inferrable");
159   return val.value();
160 }
161 
162 MMATypes MmaOp::resultPtxType() {
163   std::optional<mlir::NVVM::MMATypes> val =
164       inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
165   assert(val.has_value() && "result PTX type should always be inferrable");
166   return val.value();
167 }
168 
/// Custom printer for MmaOp. Prints the three operand segments as
/// `A[...] B[...] C[...]`, then the attribute dictionary (omitting attributes
/// that can be re-inferred on parse), then the operand/result types.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Groups each operand segment's name, its PTX-type attribute name, and the
  // register values belonging to the segment.
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  // The C fragment has no PTX-type attribute: its type is always inferred.
  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // Attributes elided from the printed attr-dict; the segment sizes are
  // implied by the bracketed operand lists.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): regTypes is only appended when the *global* operand
      // index is 0, so regTypes.back() below always refers to the first
      // operand's type for every fragment — confirm this is intended rather
      // than a per-fragment record (e.g. operandIdx == varOperandSpec.first).
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // If the PTX type can be inferred from the register type, skip printing
    // the corresponding attribute (the parser re-infers it).
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
226 
227 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
228                   ValueRange operandA, ValueRange operandB, ValueRange operandC,
229                   ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
230                   std::optional<MMAIntOverflow> intOverflow,
231                   std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
232                   std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {
233 
234   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
235   MLIRContext *ctx = builder.getContext();
236   result.addAttribute(
237       "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
238 
239   result.addOperands(operandA);
240   result.addOperands(operandB);
241   result.addOperands(operandC);
242 
243   if (multiplicandPtxTypes) {
244     result.addAttribute("multiplicandAPtxType",
245                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
246     result.addAttribute("multiplicandBPtxType",
247                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
248   } else {
249     if (auto res = inferOperandMMAType(operandA[0].getType(), false))
250       result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
251     if (auto res = inferOperandMMAType(operandB[0].getType(), false))
252       result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
253   }
254 
255   if (multiplicandLayouts) {
256     result.addAttribute("layoutA",
257                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
258     result.addAttribute("layoutB",
259                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
260   } else {
261     result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
262     result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
263   }
264 
265   if (intOverflow.has_value())
266     result.addAttribute("intOverflowBehavior",
267                         MMAIntOverflowAttr::get(ctx, *intOverflow));
268   if (b1Op.has_value())
269     result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
270 
271   result.addTypes(resultType);
272   result.addAttribute(
273       MmaOp::getOperandSegmentSizeAttr(),
274       builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
275                                     static_cast<int32_t>(operandB.size()),
276                                     static_cast<int32_t>(operandC.size())}));
277 }
278 
279 // <operation> :=
280 //   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
281 //   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
282 //     `->` type($res)
283 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
284   struct OperandFragment {
285     std::optional<MMATypes> elemtype;
286     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
287     SmallVector<Type> regTypes;
288   };
289 
290   Builder &builder = parser.getBuilder();
291   std::array<OperandFragment, 4> frags;
292 
293   NamedAttrList namedAttributes;
294 
295   // A helper to parse the operand segments.
296   auto parseMmaOperand = [&](StringRef operandName,
297                              OperandFragment &frag) -> LogicalResult {
298     if (parser.parseKeyword(operandName).failed())
299       return failure();
300     if (parser
301             .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
302             .failed())
303       return failure();
304     return success();
305   };
306 
307   // Parse the operand segments.
308   if (parseMmaOperand("A", frags[0]).failed())
309     return failure();
310   if (parseMmaOperand("B", frags[1]).failed())
311     return failure();
312   if (parseMmaOperand("C", frags[2]).failed())
313     return failure();
314 
315   if (parser.parseOptionalAttrDict(namedAttributes).failed())
316     return failure();
317 
318   // Parse the type specification and resolve operands.
319   SmallVector<Type, 3> operandTypes;
320   if (failed(parser.parseColon()))
321     return failure();
322   if (failed(parser.parseLParen()))
323     return failure();
324   if (failed(parser.parseTypeList(operandTypes)))
325     return failure();
326   if (failed(parser.parseRParen()))
327     if (operandTypes.size() != 3)
328       return parser.emitError(
329           parser.getNameLoc(),
330           "expected one type for each operand segment but got " +
331               Twine(operandTypes.size()) + " types");
332   for (const auto &iter : llvm::enumerate(operandTypes)) {
333     auto &frag = frags[iter.index()];
334     frag.regTypes.resize(frag.regs.size(), iter.value());
335     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
336                                       parser.getNameLoc(), result.operands)))
337       return failure();
338     frag.elemtype =
339         inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
340   }
341 
342   Type resultType;
343   if (parser.parseArrow() || parser.parseType(resultType))
344     return failure();
345   frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
346 
347   std::array<StringRef, 2> names{"multiplicandAPtxType",
348                                  "multiplicandBPtxType"};
349   for (unsigned idx = 0; idx < names.size(); idx++) {
350     const auto &frag = frags[idx];
351     std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
352     if (!frag.elemtype.has_value() && !attr.has_value()) {
353       return parser.emitError(
354           parser.getNameLoc(),
355           "attribute " + names[idx] +
356               " is not provided explicitly and cannot be inferred");
357     }
358     if (!attr.has_value())
359       result.addAttribute(
360           names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
361   }
362 
363   result.addTypes(resultType);
364   if (!namedAttributes.empty())
365     result.addAttributes(namedAttributes);
366   result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
367                       builder.getDenseI32ArrayAttr({
368                           static_cast<int32_t>(frags[0].regs.size()),
369                           static_cast<int32_t>(frags[1].regs.size()),
370                           static_cast<int32_t>(frags[2].regs.size()),
371                       }));
372   return success();
373 }
374 
/// Verifier for MmaOp. Builds the tables of legal (shape, operand type,
/// result type) combinations for the declared MMA shape and multiplicand PTX
/// type, then checks the actual operand segments and result against them.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants share i32 fragments and an s32 accumulator/result;
    // the float variants accept f16x2 or f32 accumulators.
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of 8x8 tiles each operand covers determines fragment counts.
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
559 
560 LogicalResult ShflOp::verify() {
561   if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
562     return success();
563   auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
564   auto elementType = (type && type.getBody().size() == 2)
565                          ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
566                          : nullptr;
567   if (!elementType || elementType.getWidth() != 1)
568     return emitError("expected return type to be a two-element struct with "
569                      "i1 as the second element");
570   return success();
571 }
572 
573 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
574                                                    NVVM::MMAFrag frag, int nRow,
575                                                    int nCol,
576                                                    MLIRContext *context) {
577   unsigned numberElements = 0;
578   Type elementType;
579   OpBuilder builder(context);
580   Type f16x2 = VectorType::get(2, builder.getF16Type());
581   if (type == NVVM::MMATypes::f16) {
582     elementType = f16x2;
583     if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
584       numberElements = 8;
585     else
586       numberElements = 4;
587   } else if (type == NVVM::MMATypes::f32) {
588     elementType = builder.getF32Type();
589     numberElements = 8;
590   } else if (type == NVVM::MMATypes::tf32) {
591     elementType = builder.getI32Type();
592     numberElements = 4;
593   } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
594     elementType = builder.getI32Type();
595     int parallelSize = 0;
596     if (frag == NVVM::MMAFrag::a)
597       parallelSize = nRow;
598     if (frag == NVVM::MMAFrag::b)
599       parallelSize = nCol;
600 
601     // m == 16 && n == 16 && k == 16
602     if (parallelSize == 16)
603       numberElements = 2;
604     // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
605     else if (parallelSize == 8)
606       numberElements = 1;
607     else if (parallelSize == 32)
608       numberElements = 4;
609   } else if (type == NVVM::MMATypes::s32) {
610     elementType = builder.getI32Type();
611     numberElements = 8;
612   }
613   assert(numberElements != 0 && elementType != nullptr);
614   return std::make_pair(elementType, numberElements);
615 }
616 
617 static std::pair<mlir::Type, unsigned>
618 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
619                     int k, MLIRContext *context) {
620   int nRow, nCol;
621   if (frag == NVVM::MMAFrag::a) {
622     nRow = m;
623     nCol = k;
624   } else if (frag == NVVM::MMAFrag::b) {
625     nRow = k;
626     nCol = n;
627   } else {
628     nRow = m;
629     nCol = n;
630   }
631   assert(nRow && nCol);
632   return inferMMAType(type, frag, nRow, nCol, context);
633 }
634 
635 LogicalResult NVVM::WMMALoadOp::verify() {
636   unsigned addressSpace =
637       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
638   if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
639       addressSpace != NVVM::kSharedMemorySpace)
640     return emitOpError("expected source pointer in memory "
641                        "space 0, 1, 3");
642 
643   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
644                                        getEltype(), getFrag()) == 0)
645     return emitOpError() << "invalid attribute combination";
646   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
647       getEltype(), getFrag(), getM(), getN(), getK(), getContext());
648   Type dstType = LLVM::LLVMStructType::getLiteral(
649       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
650   if (getType() != dstType)
651     return emitOpError("expected destination type is a structure of ")
652            << typeInfo.second << " elements of type " << typeInfo.first;
653   return success();
654 }
655 
656 LogicalResult NVVM::WMMAStoreOp::verify() {
657   unsigned addressSpace =
658       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
659   if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
660       addressSpace != NVVM::kSharedMemorySpace)
661     return emitOpError("expected operands to be a source pointer in memory "
662                        "space 0, 1, 3");
663 
664   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
665                                         getEltype()) == 0)
666     return emitOpError() << "invalid attribute combination";
667   std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
668       getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
669   if (getArgs().size() != typeInfo.second)
670     return emitOpError() << "expected " << typeInfo.second << " data operands";
671   if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
672         return operands.getType() != typeInfo.first;
673       }))
674     return emitOpError() << "expected data operands of type " << typeInfo.first;
675   return success();
676 }
677 
678 LogicalResult NVVM::WMMAMmaOp::verify() {
679   if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
680                                       getLayoutB(), getEltypeA(),
681                                       getEltypeB()) == 0)
682     return emitOpError() << "invalid attribute combination";
683   std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
684       getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
685   std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
686       getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
687   std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
688       getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
689   SmallVector<Type, 32> arguments;
690   arguments.append(typeInfoA.second, typeInfoA.first);
691   arguments.append(typeInfoB.second, typeInfoB.first);
692   arguments.append(typeInfoC.second, typeInfoC.first);
693   unsigned numArgs = arguments.size();
694   if (getArgs().size() != numArgs)
695     return emitOpError() << "expected " << numArgs << " arguments";
696   for (unsigned i = 0; i < numArgs; i++) {
697     if (getArgs()[i].getType() != arguments[i])
698       return emitOpError() << "expected argument " << i << " to be of type "
699                            << arguments[i];
700   }
701   Type dstType = LLVM::LLVMStructType::getLiteral(
702       getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
703   if (getType() != dstType)
704     return emitOpError("expected destination type is a structure of ")
705            << typeInfoC.second << " elements of type " << typeInfoC.first;
706   return success();
707 }
708 
709 LogicalResult NVVM::LdMatrixOp::verify() {
710   unsigned addressSpace =
711       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
712   if (addressSpace != NVVM::kSharedMemorySpace)
713     return emitOpError("expected source pointer in memory space 3");
714 
715   if (getNum() != 1 && getNum() != 2 && getNum() != 4)
716     return emitOpError("expected num attribute to be 1, 2 or 4");
717 
718   Type i32 = IntegerType::get(getContext(), 32);
719   if (getNum() == 1 && getType() != i32)
720     return emitOpError("expected destination type is i32");
721   if (getNum() == 2 || getNum() == 4) {
722     Type dstType = LLVM::LLVMStructType::getLiteral(
723         getContext(), SmallVector<Type>(getNum(), i32));
724     if (getType() != dstType)
725       return emitOpError("expected destination type is a structure of ")
726              << getNum() << " elements of type i32";
727   }
728   return success();
729 }
730 
731 LogicalResult NVVM::StMatrixOp::verify() {
732   unsigned addressSpace =
733       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
734   if (addressSpace != NVVM::kSharedMemorySpace)
735     return emitOpError("expected source pointer in memory space 3");
736 
737   int numMatrix = getSources().size();
738   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
739     return emitOpError("expected num attribute to be 1, 2 or 4");
740 
741   return success();
742 }
743 
744 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
745   if (typeA == NVVM::WGMMATypes::tf32)
746     return 8;
747   if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
748     return 16;
749   if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
750     return 32;
751   if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
752     return 32;
753   if (typeA == NVVM::WGMMATypes::b1)
754     return 256;
755   return failure();
756 }
757 
758 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
759                                      NVVM::WGMMATypes typeA,
760                                      NVVM::WGMMATypes typeB) {
761   switch (typeA) {
762   case NVVM::WGMMATypes::f16:
763     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
764         typeB == NVVM::WGMMATypes::f16)
765       return success();
766     break;
767   case NVVM::WGMMATypes::tf32:
768     if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
769       return success();
770     break;
771   case NVVM::WGMMATypes::u8:
772   case NVVM::WGMMATypes::s8:
773     if (typeD == NVVM::WGMMATypes::s32 &&
774         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
775       return success();
776     break;
777   case NVVM::WGMMATypes::b1:
778     if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
779       return success();
780     break;
781   case NVVM::WGMMATypes::bf16:
782     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
783         typeB == NVVM::WGMMATypes::bf16)
784       return success();
785     break;
786   case NVVM::WGMMATypes::e4m3:
787   case NVVM::WGMMATypes::e5m2:
788     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
789         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
790       return success();
791     break;
792   case WGMMATypes::f32:
793   case WGMMATypes::s32:
794     llvm_unreachable("unsupported input types");
795     break;
796   }
797   return failure();
798 }
799 
800 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
801   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
802                                72,  80,  88,  96,  104, 112, 120, 128,
803                                136, 144, 152, 160, 168, 176, 184, 192,
804                                200, 208, 216, 224, 232, 240, 248, 256};
805   SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
806                                     80,  96,  112, 128, 144, 160,
807                                     176, 192, 208, 224, 240, 256};
808   switch (typeA) {
809   case WGMMATypes::f16:
810   case WGMMATypes::tf32:
811   case WGMMATypes::bf16:
812   case WGMMATypes::e4m3:
813   case WGMMATypes::e5m2:
814     if (llvm::is_contained(allowedN, sizeN))
815       return success();
816     break;
817   case WGMMATypes::u8:
818   case WGMMATypes::s8:
819   case WGMMATypes::b1:
820     if (llvm::is_contained(allowedNshort, sizeN))
821       return success();
822     break;
823   case WGMMATypes::f32:
824   case WGMMATypes::s32:
825     llvm_unreachable("unsupported input types");
826     break;
827   }
828   return failure();
829 }
830 
831 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
832   Value outValue = getResults();
833   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
834   if (!stype)
835     return emitOpError() << "expected results to be struct";
836   int outputSize = stype.getBody().size();
837   WGMMATypes typeD = getTypeD();
838   WGMMATypes typeA = getTypeA();
839   WGMMATypes typeB = getTypeB();
840 
841   for (Type t : stype.getBody()) {
842     if (t != stype.getBody().front())
843       return emitOpError()
844              << "all elements in struct must be same type but there is " << t;
845   }
846 
847   if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
848       typeD != WGMMATypes::s32) {
849     return emitOpError() << "does not support the given output type "
850                          << NVVM::stringifyWGMMATypes(typeD);
851   }
852   if (typeD == WGMMATypes::s32 &&
853       (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
854     return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
855   }
856 
857   if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
858     return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
859                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
860                          << NVVM::stringifyWGMMATypes(typeB)
861                          << ", it is not supported.";
862   }
863 
864   // Check M
865   if (getShape().getM() != 64)
866     return emitOpError() << "shape 'm' must be 64";
867 
868   // Check K
869   FailureOr<int> allowedK = getAllowedSizeK(typeA);
870   if (failed(allowedK) || allowedK.value() != getShape().getK())
871     return emitOpError() << "shape 'k' must be " << allowedK.value()
872                          << " for input type "
873                          << NVVM::stringifyWGMMATypes(typeA);
874 
875   // Check N
876   if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
877     return emitOpError() << "has input type "
878                          << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
879                          << getShape().getN() << ", it is not supported.";
880   }
881 
882   // Check transpose (only available for f16/bf16)
883   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
884       (getLayoutA() == mlir::NVVM::MMALayout::col ||
885        getLayoutB() == mlir::NVVM::MMALayout::col)) {
886     return emitOpError()
887            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
888            << " and layout_b = " << stringifyMMALayout(getLayoutB())
889            << " for input types " << stringifyWGMMATypes(typeA) << " and "
890            << stringifyWGMMATypes(typeB)
891            << " requires transpose. However, this is only supported for: "
892            << stringifyMMATypes(MMATypes::f16) << " and "
893            << stringifyMMATypes(MMATypes::bf16);
894   }
895 
896   // Check result registers
897   int expectedOutput = 0;
898   if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
899     expectedOutput = getShape().getN() / 2;
900   if (typeD == WGMMATypes::f16)
901     expectedOutput = getShape().getN() / 4;
902   if (outputSize != expectedOutput) {
903     return emitOpError() << "results " << expectedOutput
904                          << ", however output struct has " << outputSize
905                          << " elements";
906   }
907   // Check satfinite (only available for s32 accumulator)
908   if (typeD != WGMMATypes::s32 &&
909       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
910           NVVM::MMAIntOverflow::satfinite) {
911     return emitOpError()
912            << " `satfinite` can be only used with s32 accumulator, however "
913               "the current accumulator is "
914            << NVVM::stringifyWGMMATypes(typeD);
915   }
916 
917   return success();
918 }
919 
/// Builds the PTX inline-assembly string for this wgmma.mma_async op. The
/// operand placeholders ($0, $1, ...) must stay in sync with the value list
/// produced by getAsmValues().
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  // Layout (transpose) operands are only emitted for f16/bf16 inputs.
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // Number of 32-bit output registers: n/4 when the accumulator is f16,
  // otherwise n/2 (same computation as in verify()).
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // Predicate `p` is set from the operand at index 2*numOutRegs + 2.
  // NOTE(review): the doubling presumably accounts for read-write operands
  // being numbered twice in inline-asm operand lists -- confirm against the
  // LLVM inline-asm operand numbering rules.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  // Output register list: {$0, ..., $N-1}.
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  // Descriptor A, descriptor B, and the scale-d predicate computed above.
  ss << " $" << (regCnt) << ","
     << " $" << (regCnt + 1) << ","
     << " p";
  // Scale operands exist only for non-s32 accumulators (see getAsmValues()).
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  ss.flush();
  return ptx;
}
974 
/// Collects the inline-assembly operands for this op into `asmValues`, in
/// the order the register placeholders in getPtx() expect: results (write),
/// inouts (read-write), descriptor A, descriptor B, scale-d, then the
/// optional scale-a/scale-b and layout immediates.
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  // Layout immediates are only appended for f16/bf16 inputs, mirroring the
  // `isF16` condition in getPtx().
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  // scale-d, materialized as an i32 constant.
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  // scale-a / scale-b (-1 for neg, +1 otherwise); only present for
  // non-s32 accumulators (verify() rejects neg scaling with s32).
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    // Layout of A, passed through as an immediate.
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    // Layout of B is inverted (1 - layout). NOTE(review): presumably this
    // matches the PTX trans-b encoding convention -- confirm.
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
1007 LogicalResult NVVM::FenceProxyOp::verify() {
1008   if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1009     return emitOpError() << "async_shared fence requires space attribute";
1010   }
1011   if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1012     return emitOpError() << "only async_shared fence can have space attribute";
1013   }
1014   return success();
1015 }
1016 
1017 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1018   if (getRegCount() % 8)
1019     return emitOpError("new register size must be multiple of 8");
1020   if (getRegCount() < 24 || getRegCount() > 256)
1021     return emitOpError("new register size must be in between 24 to 256");
1022   return success();
1023 }
1024 
1025 LogicalResult NVVM::BarrierOp::verify() {
1026   if (getNumberOfThreads() && !getBarrierId())
1027     return emitOpError(
1028         "barrier id is missing, it should be set between 0 to 15");
1029   return success();
1030 }
1031 
1032 //===----------------------------------------------------------------------===//
1033 // NVVMDialect initialization, type parsing, and registration.
1034 //===----------------------------------------------------------------------===//
1035 
1036 // TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  // Register all ODS-generated NVVM operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Register all ODS-generated NVVM attributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  // Promise these interfaces; their implementations are attached later by
  // the extensions that register them, avoiding a dependency here.
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
1053 
1054 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1055                                                     NamedAttribute attr) {
1056   StringAttr attrName = attr.getName();
1057   // Kernel function attribute should be attached to functions.
1058   if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1059     auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op);
1060     if (!funcOp) {
1061       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1062                              << "' attribute attached to unexpected op";
1063     }
1064     if (!funcOp.getResultTypes().empty()) {
1065       return op->emitError() << "kernel function cannot have results";
1066     }
1067   }
1068   // If maxntid and reqntid exist, it must be an array with max 3 dim
1069   if (attrName == NVVMDialect::getMaxntidAttrName() ||
1070       attrName == NVVMDialect::getReqntidAttrName()) {
1071     auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1072     if (!values || values.empty() || values.size() > 3)
1073       return op->emitError()
1074              << "'" << attrName
1075              << "' attribute must be integer array with maximum 3 index";
1076   }
1077   // If minctasm and maxnreg exist, it must be an integer attribute
1078   if (attrName == NVVMDialect::getMinctasmAttrName() ||
1079       attrName == NVVMDialect::getMaxnregAttrName()) {
1080     if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1081       return op->emitError()
1082              << "'" << attrName << "' attribute must be integer constant";
1083   }
1084 
1085   return success();
1086 }
1087 
1088 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1089                                                     unsigned regionIndex,
1090                                                     unsigned argIndex,
1091                                                     NamedAttribute argAttr) {
1092   auto funcOp = dyn_cast<FunctionOpInterface>(op);
1093   if (!funcOp)
1094     return success();
1095 
1096   bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1097   StringAttr attrName = argAttr.getName();
1098   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1099     if (!isKernel) {
1100       return op->emitError()
1101              << "'" << attrName
1102              << "' attribute must be present only on kernel arguments";
1103     }
1104     if (!isa<UnitAttr>(argAttr.getValue()))
1105       return op->emitError() << "'" << attrName << "' must be a unit attribute";
1106     if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1107       return op->emitError()
1108              << "'" << attrName
1109              << "' attribute requires the argument to also have attribute '"
1110              << LLVM::LLVMDialect::getByValAttrName() << "'";
1111     }
1112   }
1113 
1114   return success();
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // NVVM target attribute.
1119 //===----------------------------------------------------------------------===//
1120 LogicalResult
1121 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1122                        int optLevel, StringRef triple, StringRef chip,
1123                        StringRef features, DictionaryAttr flags,
1124                        ArrayAttr files) {
1125   if (optLevel < 0 || optLevel > 3) {
1126     emitError() << "The optimization level must be a number between 0 and 3.";
1127     return failure();
1128   }
1129   if (triple.empty()) {
1130     emitError() << "The target triple cannot be empty.";
1131     return failure();
1132   }
1133   if (chip.empty()) {
1134     emitError() << "The target chip cannot be empty.";
1135     return failure();
1136   }
1137   if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1138         return attr && mlir::isa<StringAttr>(attr);
1139       })) {
1140     emitError() << "All the elements in the `link` array must be strings.";
1141     return failure();
1142   }
1143   return success();
1144 }
1145 
1146 #define GET_OP_CLASSES
1147 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1148 
1149 #define GET_ATTRDEF_CLASSES
1150 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1151