xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (revision 1b23ebe0770aaf85f37e085b53067066d2d99cc8)
1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21 #include "mlir/Dialect/Utils/StaticValueUtils.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Diagnostics.h"
26 #include "mlir/IR/DialectImplementation.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/OperationSupport.h"
30 #include "mlir/IR/Types.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/AsmParser/Parser.h"
34 #include "llvm/IR/Attributes.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/Casting.h"
38 #include "llvm/Support/SourceMgr.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <cassert>
41 #include <optional>
42 #include <string>
43 
44 using namespace mlir;
45 using namespace NVVM;
46 
47 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
48 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
49 
50 //===----------------------------------------------------------------------===//
51 // Printing/parsing for NVVM ops
52 //===----------------------------------------------------------------------===//
53 
54 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
55   p << " " << op->getOperands();
56   if (op->getNumResults() > 0)
57     p << " : " << op->getResultTypes();
58 }
59 
60 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
61 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
62   MLIRContext *context = parser.getContext();
63   auto int32Ty = IntegerType::get(context, 32);
64   auto int1Ty = IntegerType::get(context, 1);
65 
66   SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
67   Type type;
68   return failure(parser.parseOperandList(ops) ||
69                  parser.parseOptionalAttrDict(result.attributes) ||
70                  parser.parseColonType(type) ||
71                  parser.addTypeToList(type, result.types) ||
72                  parser.resolveOperands(ops, {int32Ty, int1Ty},
73                                         parser.getNameLoc(), result.operands));
74 }
75 
76 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
77 
78 // This verifier is shared across:
79 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
80 // CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
81 static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
82                                                      size_t numIm2ColOffsets,
83                                                      Location loc) {
84   if (tensorDims < 1 || tensorDims > 5)
85     return emitError(loc, "expects coordinates between 1 to 5 dimension");
86 
87   if (numIm2ColOffsets) {
88     if (tensorDims < 3)
89       return emitError(
90           loc,
91           "to use im2col mode, the tensor has to be at least 3-dimensional");
92     if (tensorDims != (numIm2ColOffsets + 2))
93       return emitError(
94           loc, "im2col offsets must be 2 less than number of coordinates");
95   }
96   return success();
97 }
98 
99 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
100   return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
101                                          getIm2colOffsets().size(), getLoc());
102 }
103 
104 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
105   if (getCoordinates().size() > 5)
106     return emitError("Maximum 5 coordinates and dimension is supported.");
107   return success();
108 }
109 
110 LogicalResult CpAsyncOp::verify() {
111   if (getModifier() != LoadCacheModifierKind::CG &&
112       getModifier() != LoadCacheModifierKind::CA)
113     return emitError("Only CG and CA cache modifiers are supported.");
114   if (getSize() != 4 && getSize() != 8 && getSize() != 16)
115     return emitError("expected byte size to be either 4, 8 or 16.");
116   if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
117     return emitError("CG cache modifier is only support for 16 bytes copy.");
118   return success();
119 }
120 
121 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
122   return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
123                                          getIm2colOffsets().size(), getLoc());
124 }
125 
126 // Given the element type of an operand and whether or not it is an accumulator,
127 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
128 // operand's element type.
129 std::optional<mlir::NVVM::MMATypes>
130 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
131   auto half2Type =
132       LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
133   if (operandElType.isF64())
134     return NVVM::MMATypes::f64;
135   if (operandElType.isF16() || operandElType == half2Type)
136     return NVVM::MMATypes::f16;
137   if (operandElType.isF32() && isAccumulator)
138     return NVVM::MMATypes::f32;
139   if (operandElType.isF32() && !isAccumulator)
140     return NVVM::MMATypes::tf32;
141   if (llvm::isa<IntegerType>(operandElType)) {
142     if (isAccumulator)
143       return NVVM::MMATypes::s32;
144     return std::nullopt;
145   }
146 
147   if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
148     if (structType.getBody().empty())
149       return std::nullopt;
150     return inferOperandMMAType(structType.getBody()[0], isAccumulator);
151   }
152 
153   return std::nullopt;
154 }
155 
156 static bool isInt4PtxType(MMATypes type) {
157   return (type == MMATypes::u4 || type == MMATypes::s4);
158 }
159 
160 static bool isInt8PtxType(MMATypes type) {
161   return (type == MMATypes::u8 || type == MMATypes::s8);
162 }
163 
164 static bool isIntegerPtxType(MMATypes type) {
165   return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
166          type == MMATypes::s32;
167 }
168 
169 MMATypes MmaOp::accumPtxType() {
170   std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
171       getODSOperands(2).getTypes().front(), /*isAccum=*/true);
172   assert(val.has_value() && "accumulator PTX type should always be inferrable");
173   return val.value();
174 }
175 
176 MMATypes MmaOp::resultPtxType() {
177   std::optional<mlir::NVVM::MMATypes> val =
178       inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
179   assert(val.has_value() && "result PTX type should always be inferrable");
180   return val.value();
181 }
182 
/// Custom printer for MmaOp, the inverse of MmaOp::parse.
///
/// Prints the op as:
///   A[<regs>] B[<regs>] C[<regs>] attr-dict
///     : (<A-reg type>, <B-reg type>, <C-reg type>) -> <result type>
/// Multiplicand ptx-type attributes that the parser can re-infer from the
/// operand types are elided from the attr-dict.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Groups the printed name, the ptx-type attribute name, and the register
  // values of one operand segment (A, B, or C).
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // The segment-size attribute is implied by the printed form, never printed.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): only absolute operand index 0 records a type, so
      // regTypes.back() below is always the A-segment register type, even
      // when deciding elision for the B and C fragments — confirm this is
      // the intended behavior.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // When the ptx type is inferable from the register type, the parser can
    // reconstruct it, so the attribute may be elided from the output.
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  // Prints one operand segment as ` <name>[%reg, ...] `.
  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
239 
/// Builder for `nvvm.mma.sync`.
///
/// `shape` must contain exactly the (m, n, k) tile sizes. When
/// `multiplicandPtxTypes` is absent, the A/B ptx types are inferred from the
/// first register of each segment (inference can fail for integer types, in
/// which case the attribute is simply omitted). When `multiplicandLayouts`
/// is absent, the layouts default to row-major A and column-major B.
/// `intOverflow` and `b1Op` are only attached when provided.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Attach the multiplicand ptx types: explicit if given, inferred otherwise.
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  // Attach layouts: explicit if given, otherwise row-major A / col-major B.
  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record the A/B/C segment sizes for the variadic operand groups.
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}
291 
292 // <operation> :=
293 //   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
294 //   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
295 //     `->` type($res)
296 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
297   struct OperandFragment {
298     std::optional<MMATypes> elemtype;
299     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
300     SmallVector<Type> regTypes;
301   };
302 
303   Builder &builder = parser.getBuilder();
304   std::array<OperandFragment, 4> frags;
305 
306   NamedAttrList namedAttributes;
307 
308   // A helper to parse the operand segments.
309   auto parseMmaOperand = [&](StringRef operandName,
310                              OperandFragment &frag) -> LogicalResult {
311     if (parser.parseKeyword(operandName).failed())
312       return failure();
313     if (parser
314             .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
315             .failed())
316       return failure();
317     return success();
318   };
319 
320   // Parse the operand segments.
321   if (parseMmaOperand("A", frags[0]).failed())
322     return failure();
323   if (parseMmaOperand("B", frags[1]).failed())
324     return failure();
325   if (parseMmaOperand("C", frags[2]).failed())
326     return failure();
327 
328   if (parser.parseOptionalAttrDict(namedAttributes).failed())
329     return failure();
330 
331   // Parse the type specification and resolve operands.
332   SmallVector<Type, 3> operandTypes;
333   if (failed(parser.parseColon()))
334     return failure();
335   if (failed(parser.parseLParen()))
336     return failure();
337   if (failed(parser.parseTypeList(operandTypes)))
338     return failure();
339   if (failed(parser.parseRParen()))
340     if (operandTypes.size() != 3)
341       return parser.emitError(
342           parser.getNameLoc(),
343           "expected one type for each operand segment but got " +
344               Twine(operandTypes.size()) + " types");
345   for (const auto &iter : llvm::enumerate(operandTypes)) {
346     auto &frag = frags[iter.index()];
347     frag.regTypes.resize(frag.regs.size(), iter.value());
348     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
349                                       parser.getNameLoc(), result.operands)))
350       return failure();
351     frag.elemtype =
352         inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
353   }
354 
355   Type resultType;
356   if (parser.parseArrow() || parser.parseType(resultType))
357     return failure();
358   frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
359 
360   std::array<StringRef, 2> names{"multiplicandAPtxType",
361                                  "multiplicandBPtxType"};
362   for (unsigned idx = 0; idx < names.size(); idx++) {
363     const auto &frag = frags[idx];
364     std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
365     if (!frag.elemtype.has_value() && !attr.has_value()) {
366       return parser.emitError(
367           parser.getNameLoc(),
368           "attribute " + names[idx] +
369               " is not provided explicitly and cannot be inferred");
370     }
371     if (!attr.has_value())
372       result.addAttribute(
373           names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
374   }
375 
376   result.addTypes(resultType);
377   if (!namedAttributes.empty())
378     result.addAttributes(namedAttributes);
379   result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
380                       builder.getDenseI32ArrayAttr({
381                           static_cast<int32_t>(frags[0].regs.size()),
382                           static_cast<int32_t>(frags[1].regs.size()),
383                           static_cast<int32_t>(frags[2].regs.size()),
384                       }));
385   return success();
386 }
387 
/// Verifies an `nvvm.mma.sync` op by building, from the declared shape and
/// multiplicand ptx type, the sets of allowed (m, n, k) shapes and allowed
/// register types for the A/B/C segments and the result, then checking the
/// op against those sets. Also enforces the b1Op attribute for binary MMA
/// and the overflow-behavior attribute for int4/int8 MMA.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  // Frequently used scalar/vector/struct register types.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants always accumulate into s32 and use i32 registers for
    // the multiplicand fragments (multiplicandFragType was not set above).
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of registers per multiplicand = (#8-row tiles) * (#k tiles).
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      // The k extent depends on how many elements pack into one i32 register.
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  // Diagnostics below are accumulated into this buffer before emission.
  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    // The whole segment's type list must equal one of the allowed lists.
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
572 
573 LogicalResult ShflOp::verify() {
574   if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
575     return success();
576   auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
577   auto elementType = (type && type.getBody().size() == 2)
578                          ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
579                          : nullptr;
580   if (!elementType || elementType.getWidth() != 1)
581     return emitError("expected return type to be a two-element struct with "
582                      "i1 as the second element");
583   return success();
584 }
585 
586 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
587                                                    NVVM::MMAFrag frag, int nRow,
588                                                    int nCol,
589                                                    MLIRContext *context) {
590   unsigned numberElements = 0;
591   Type elementType;
592   OpBuilder builder(context);
593   Type f16x2 = VectorType::get(2, builder.getF16Type());
594   if (type == NVVM::MMATypes::f16) {
595     elementType = f16x2;
596     if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
597       numberElements = 8;
598     else
599       numberElements = 4;
600   } else if (type == NVVM::MMATypes::f32) {
601     elementType = builder.getF32Type();
602     numberElements = 8;
603   } else if (type == NVVM::MMATypes::tf32) {
604     elementType = builder.getI32Type();
605     numberElements = 4;
606   } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
607     elementType = builder.getI32Type();
608     int parallelSize = 0;
609     if (frag == NVVM::MMAFrag::a)
610       parallelSize = nRow;
611     if (frag == NVVM::MMAFrag::b)
612       parallelSize = nCol;
613 
614     // m == 16 && n == 16 && k == 16
615     if (parallelSize == 16)
616       numberElements = 2;
617     // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
618     else if (parallelSize == 8)
619       numberElements = 1;
620     else if (parallelSize == 32)
621       numberElements = 4;
622   } else if (type == NVVM::MMATypes::s32) {
623     elementType = builder.getI32Type();
624     numberElements = 8;
625   }
626   assert(numberElements != 0 && elementType != nullptr);
627   return std::make_pair(elementType, numberElements);
628 }
629 
630 static std::pair<mlir::Type, unsigned>
631 inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
632                     int k, MLIRContext *context) {
633   int nRow, nCol;
634   if (frag == NVVM::MMAFrag::a) {
635     nRow = m;
636     nCol = k;
637   } else if (frag == NVVM::MMAFrag::b) {
638     nRow = k;
639     nCol = n;
640   } else {
641     nRow = m;
642     nCol = n;
643   }
644   assert(nRow && nCol);
645   return inferMMAType(type, frag, nRow, nCol, context);
646 }
647 
/// Verifies an `nvvm.wmma.load`:
///  * the source pointer must be in the generic (0), global
///    (kGlobalMemorySpace) or shared (kSharedMemorySpace) address space,
///  * the (m, n, k, layout, eltype, frag) combination must correspond to an
///    actual intrinsic, and
///  * the result must be a literal struct of the fragment's register
///    type/count.
LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  // getIntrinsicID returns 0 when no intrinsic exists for the combination.
  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}
668 
/// Verifies an `nvvm.wmma.store`:
///  * the destination pointer must be in the generic (0), global
///    (kGlobalMemorySpace) or shared (kSharedMemorySpace) address space,
///  * the (m, n, k, layout, eltype) combination must correspond to an actual
///    intrinsic, and
///  * the data operands must match the C-fragment register count and type.
LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected operands to be a source pointer in memory "
                       "space 0, 1, 3");

  // getIntrinsicID returns 0 when no intrinsic exists for the combination.
  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}
690 
/// Verifies an `nvvm.wmma.mma`: the attribute combination must map to an
/// actual intrinsic, the arguments must be exactly the A-, B-, then
/// C-fragment registers, and the result must be a struct shaped like the C
/// fragment. Note `getEltypeA()` supplies the element type for *both* the A
/// and B fragments, while `getEltypeB()` supplies the C/accumulator type.
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  // Build the expected flat argument list: A registers, then B, then C.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result has the same register layout as the C fragment.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
721 
/// Verifies an `nvvm.ldmatrix`: the source pointer must be in shared memory
/// (kSharedMemorySpace), `num` must be 1, 2 or 4, and the result type must
/// match: a plain i32 for num == 1, otherwise a literal struct of `num` i32
/// values.
LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type is i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type is a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}
743 
744 LogicalResult NVVM::StMatrixOp::verify() {
745   unsigned addressSpace =
746       llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
747   if (addressSpace != NVVM::kSharedMemorySpace)
748     return emitOpError("expected source pointer in memory space 3");
749 
750   int numMatrix = getSources().size();
751   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
752     return emitOpError("expected num attribute to be 1, 2 or 4");
753 
754   return success();
755 }
756 
757 FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
758   if (typeA == NVVM::WGMMATypes::tf32)
759     return 8;
760   if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
761     return 16;
762   if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
763     return 32;
764   if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
765     return 32;
766   if (typeA == NVVM::WGMMATypes::b1)
767     return 256;
768   return failure();
769 }
770 
771 LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
772                                      NVVM::WGMMATypes typeA,
773                                      NVVM::WGMMATypes typeB) {
774   switch (typeA) {
775   case NVVM::WGMMATypes::f16:
776     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
777         typeB == NVVM::WGMMATypes::f16)
778       return success();
779     break;
780   case NVVM::WGMMATypes::tf32:
781     if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
782       return success();
783     break;
784   case NVVM::WGMMATypes::u8:
785   case NVVM::WGMMATypes::s8:
786     if (typeD == NVVM::WGMMATypes::s32 &&
787         (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
788       return success();
789     break;
790   case NVVM::WGMMATypes::b1:
791     if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
792       return success();
793     break;
794   case NVVM::WGMMATypes::bf16:
795     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
796         typeB == NVVM::WGMMATypes::bf16)
797       return success();
798     break;
799   case NVVM::WGMMATypes::e4m3:
800   case NVVM::WGMMATypes::e5m2:
801     if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
802         (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
803       return success();
804     break;
805   case WGMMATypes::f32:
806   case WGMMATypes::s32:
807     llvm_unreachable("unsupported input types");
808     break;
809   }
810   return failure();
811 }
812 
813 LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
814   SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
815                                72,  80,  88,  96,  104, 112, 120, 128,
816                                136, 144, 152, 160, 168, 176, 184, 192,
817                                200, 208, 216, 224, 232, 240, 248, 256};
818   SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
819                                     80,  96,  112, 128, 144, 160,
820                                     176, 192, 208, 224, 240, 256};
821   switch (typeA) {
822   case WGMMATypes::f16:
823   case WGMMATypes::tf32:
824   case WGMMATypes::bf16:
825   case WGMMATypes::e4m3:
826   case WGMMATypes::e5m2:
827     if (llvm::is_contained(allowedN, sizeN))
828       return success();
829     break;
830   case WGMMATypes::u8:
831   case WGMMATypes::s8:
832   case WGMMATypes::b1:
833     if (llvm::is_contained(allowedNshort, sizeN))
834       return success();
835     break;
836   case WGMMATypes::f32:
837   case WGMMATypes::s32:
838     llvm_unreachable("unsupported input types");
839     break;
840   }
841   return failure();
842 }
843 
844 LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
845   Value outValue = getResults();
846   auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
847   if (!stype)
848     return emitOpError() << "expected results to be struct";
849   int outputSize = stype.getBody().size();
850   WGMMATypes typeD = getTypeD();
851   WGMMATypes typeA = getTypeA();
852   WGMMATypes typeB = getTypeB();
853 
854   for (Type t : stype.getBody()) {
855     if (t != stype.getBody().front())
856       return emitOpError()
857              << "all elements in struct must be same type but there is " << t;
858   }
859 
860   if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
861       typeD != WGMMATypes::s32) {
862     return emitOpError() << "does not support the given output type "
863                          << NVVM::stringifyWGMMATypes(typeD);
864   }
865   if (typeD == WGMMATypes::s32 &&
866       (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
867     return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
868   }
869 
870   if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
871     return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
872                          << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
873                          << NVVM::stringifyWGMMATypes(typeB)
874                          << ", it is not supported.";
875   }
876 
877   // Check M
878   if (getShape().getM() != 64)
879     return emitOpError() << "shape 'm' must be 64";
880 
881   // Check K
882   FailureOr<int> allowedK = getAllowedSizeK(typeA);
883   if (failed(allowedK) || allowedK.value() != getShape().getK())
884     return emitOpError() << "shape 'k' must be " << allowedK.value()
885                          << " for input type "
886                          << NVVM::stringifyWGMMATypes(typeA);
887 
888   // Check N
889   if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
890     return emitOpError() << "has input type "
891                          << NVVM::stringifyWGMMATypes(typeA) << " n is set to "
892                          << getShape().getN() << ", it is not supported.";
893   }
894 
895   // Check transpose (only available for f16/bf16)
896   // Matrices A should be stored in row-major and B in column-major.
897   // Only f16/bf16 matrices can be stored in either column-major or row-major
898   // by setting the tranpose value(imm-trans-a,imm-trans-b) in PTX code.
899   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
900       (getLayoutA() == mlir::NVVM::MMALayout::col ||
901        getLayoutB() == mlir::NVVM::MMALayout::row)) {
902     return emitOpError()
903            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
904            << " and layout_b = " << stringifyMMALayout(getLayoutB())
905            << " for input types " << stringifyWGMMATypes(typeA) << " and "
906            << stringifyWGMMATypes(typeB)
907            << " requires transpose. However, this is only supported for: "
908            << stringifyMMATypes(MMATypes::f16) << " and "
909            << stringifyMMATypes(MMATypes::bf16);
910   }
911 
912   // Check result registers
913   int expectedOutput = 0;
914   if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
915     expectedOutput = getShape().getN() / 2;
916   if (typeD == WGMMATypes::f16)
917     expectedOutput = getShape().getN() / 4;
918   if (outputSize != expectedOutput) {
919     return emitOpError() << "results " << expectedOutput
920                          << ", however output struct has " << outputSize
921                          << " elements";
922   }
923   // Check satfinite (only available for s32 accumulator)
924   if (typeD != WGMMATypes::s32 &&
925       getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
926           NVVM::MMAIntOverflow::satfinite) {
927     return emitOpError()
928            << " `satfinite` can be only used with s32 accumulator, however "
929               "the current accumulator is "
930            << NVVM::stringifyWGMMATypes(typeD);
931   }
932 
933   return success();
934 }
935 
/// Builds the inline-PTX string for this wgmma.mma_async op. Placeholder
/// operands ($0, $1, ...) must line up exactly with the value list produced
/// by getAsmValues(): output registers first, then the read operands.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  // 32-bit accumulators pack N/2 values per fragment, f16 packs N/4
  // (mirrors the register-count check in verify()).
  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  // Predicate p holds scale-d. NOTE(review): the operand index
  // (2*outputs + 2) assumes each output register occupies two asm operand
  // slots followed by the two descriptors — confirm against getAsmValues.
  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  // Emit the accumulator register list: {$0, $1, ..., $E-1}.
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
  regCnt = (regCnt * 2);
  // Descriptor A, descriptor B, and the scale-d predicate.
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  // Scale-a/scale-b immediates are only present for non-s32 accumulators
  // (verify() rejects neg scales with s32; getAsmValues skips them).
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}
987 
/// Collects the inline-asm operand values for this op, in the exact order
/// the placeholders emitted by getPtx() expect: results (write), inouts
/// (read-write), descriptor A, descriptor B, scale-d, then — for non-s32
/// accumulators — scale-a/scale-b, and — for f16/bf16 inputs — the two
/// layout immediates.
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  // scale-d is materialized as an i32 constant feeding the setp in the PTX.
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  // Scale immediates are omitted for s32 accumulation (see verify()).
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  // Layout/transpose immediates exist only for f16/bf16 inputs. NOTE(review):
  // layout B is inverted (1 - value) relative to layout A — presumably the
  // PTX transpose bit has opposite polarity for B; confirm against PTX ISA.
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}
1020 LogicalResult NVVM::FenceProxyOp::verify() {
1021   if (getKind() == NVVM::ProxyKind::TENSORMAP)
1022     return emitOpError() << "tensormap proxy is not a supported proxy kind";
1023   if (getKind() == NVVM::ProxyKind::GENERIC)
1024     return emitOpError() << "generic proxy not a supported proxy kind";
1025   if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
1026     return emitOpError() << "async_shared fence requires space attribute";
1027   }
1028   if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
1029     return emitOpError() << "only async_shared fence can have space attribute";
1030   }
1031   return success();
1032 }
1033 
1034 LogicalResult NVVM::FenceProxyAcquireOp::verify() {
1035   if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1036     return emitOpError("uni-directional proxies only support generic for "
1037                        "from_proxy attribute");
1038 
1039   if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1040     return emitOpError("uni-directional proxies only support tensormap "
1041                        "for to_proxy attribute");
1042 
1043   return success();
1044 }
1045 
1046 LogicalResult NVVM::FenceProxyReleaseOp::verify() {
1047   if (getFromProxy() != NVVM::ProxyKind::GENERIC)
1048     return emitOpError("uni-directional proxies only support generic for "
1049                        "from_proxy attribute");
1050 
1051   if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
1052     return emitOpError("uni-directional proxies only support tensormap "
1053                        "for to_proxy attribute");
1054 
1055   return success();
1056 }
1057 
1058 LogicalResult NVVM::SetMaxRegisterOp::verify() {
1059   if (getRegCount() % 8)
1060     return emitOpError("new register size must be multiple of 8");
1061   if (getRegCount() < 24 || getRegCount() > 256)
1062     return emitOpError("new register size must be in between 24 to 256");
1063   return success();
1064 }
1065 
1066 LogicalResult NVVM::BarrierOp::verify() {
1067   if (getNumberOfThreads() && !getBarrierId())
1068     return emitOpError(
1069         "barrier id is missing, it should be set between 0 to 15");
1070   return success();
1071 }
1072 
1073 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
1074                                                                 bool isIm2Col) {
1075   switch (tensorDims) {
1076   case 1:
1077     return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
1078   case 2:
1079     return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
1080   case 3:
1081     return isIm2Col
1082                ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
1083                : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
1084   case 4:
1085     return isIm2Col
1086                ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
1087                : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
1088   case 5:
1089     return isIm2Col
1090                ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
1091                : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
1092   default:
1093     llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
1094   }
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // NVVMDialect initialization, type parsing, and registration.
1099 //===----------------------------------------------------------------------===//
1100 
1101 // TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  // Register the TableGen-generated op list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Register the TableGen-generated attribute list.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  // Interface implementations live in separate libraries; declaring them as
  // promised lets those libraries attach them later.
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}
1118 
1119 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
1120                                                     NamedAttribute attr) {
1121   StringAttr attrName = attr.getName();
1122   // Kernel function attribute should be attached to functions.
1123   if (attrName == NVVMDialect::getKernelFuncAttrName()) {
1124     if (!isa<LLVM::LLVMFuncOp>(op)) {
1125       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
1126                              << "' attribute attached to unexpected op";
1127     }
1128   }
1129   // If maxntid and reqntid exist, it must be an array with max 3 dim
1130   if (attrName == NVVMDialect::getMaxntidAttrName() ||
1131       attrName == NVVMDialect::getReqntidAttrName()) {
1132     auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
1133     if (!values || values.empty() || values.size() > 3)
1134       return op->emitError()
1135              << "'" << attrName
1136              << "' attribute must be integer array with maximum 3 index";
1137   }
1138   // If minctasm and maxnreg exist, it must be an integer attribute
1139   if (attrName == NVVMDialect::getMinctasmAttrName() ||
1140       attrName == NVVMDialect::getMaxnregAttrName()) {
1141     if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
1142       return op->emitError()
1143              << "'" << attrName << "' attribute must be integer constant";
1144   }
1145 
1146   return success();
1147 }
1148 
1149 LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
1150                                                     unsigned regionIndex,
1151                                                     unsigned argIndex,
1152                                                     NamedAttribute argAttr) {
1153   auto funcOp = dyn_cast<FunctionOpInterface>(op);
1154   if (!funcOp)
1155     return success();
1156 
1157   bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
1158   StringAttr attrName = argAttr.getName();
1159   if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
1160     if (!isKernel) {
1161       return op->emitError()
1162              << "'" << attrName
1163              << "' attribute must be present only on kernel arguments";
1164     }
1165     if (!isa<UnitAttr>(argAttr.getValue()))
1166       return op->emitError() << "'" << attrName << "' must be a unit attribute";
1167     if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
1168       return op->emitError()
1169              << "'" << attrName
1170              << "' attribute requires the argument to also have attribute '"
1171              << LLVM::LLVMDialect::getByValAttrName() << "'";
1172     }
1173   }
1174 
1175   return success();
1176 }
1177 
1178 //===----------------------------------------------------------------------===//
1179 // NVVM target attribute.
1180 //===----------------------------------------------------------------------===//
1181 LogicalResult
1182 NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1183                        int optLevel, StringRef triple, StringRef chip,
1184                        StringRef features, DictionaryAttr flags,
1185                        ArrayAttr files) {
1186   if (optLevel < 0 || optLevel > 3) {
1187     emitError() << "The optimization level must be a number between 0 and 3.";
1188     return failure();
1189   }
1190   if (triple.empty()) {
1191     emitError() << "The target triple cannot be empty.";
1192     return failure();
1193   }
1194   if (chip.empty()) {
1195     emitError() << "The target chip cannot be empty.";
1196     return failure();
1197   }
1198   if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
1199         return attr && mlir::isa<StringAttr>(attr);
1200       })) {
1201     emitError() << "All the elements in the `link` array must be strings.";
1202     return failure();
1203   }
1204   return success();
1205 }
1206 
1207 #define GET_OP_CLASSES
1208 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
1209 
1210 #define GET_ATTRDEF_CLASSES
1211 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1212