//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and it registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
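// A minimal sketch of the textual form this parser accepts (the SSA names
// below are illustrative, not taken from a real test case):
//   %0 = nvvm.vote.ballot.sync %mask, %pred : i32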
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects between 1 and 5 coordinates");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
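    //    For example (illustrative): a 5-d im2col copy must carry
    //    5 - 2 = 3 im2col offsets, and a 3-d one exactly 1.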
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("a maximum of 5 coordinates is supported");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
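// For example (illustrative): an f32 accumulator maps to MMATypes::f32 while
// an f32 multiplicand maps to MMATypes::tf32, and a vector<2xf16> operand
// maps to MMATypes::f16 in either position.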
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // Record the first register type of each segment so the element type of
      // this fragment can be inferred below.
      if (operandIdx == varOperandSpec.first) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : (";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//     `->` type($res)
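//
// A minimal sketch of a matching op (the register counts and attribute
// values below are illustrative, not an exhaustive or verified combination):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//            {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//             shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//       : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>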
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() == 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
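  // For example (illustrative): with f16 operands, kFactor = 8, so an
  // m16n8k16 op expects (16/8)*(16/8) = 4 A fragments and (8/8)*(16/8) = 2
  // B fragments of vector<2xf16>.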
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

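// A minimal sketch of what the `return_value_and_is_valid` variant returns
// (illustrative): a shfl producing an i32 value would have the result type
// !llvm.struct<(i32, i1)>, where the trailing i1 reports validity.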
LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

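// For example (illustrative): an f16 `a` or `b` fragment is returned as
// 8 x vector<2xf16>, while an f16 accumulator fragment is 4 x vector<2xf16>.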
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer to be in memory "
                       "space 0, 1 or 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer to be in memory "
                       "space 0, 1 or 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
        return operands.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be the same type but there is "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << " is not supported";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA)
                         << " but n is set to " << getShape().getN()
                         << ", which is not supported";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored in row-major and B in column-major layout.
  // Only f16/bf16 matrices can be stored in either column-major or row-major
  // by setting the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " require transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " result elements, however the output struct has "
                         << outputSize << " elements";
  }
  // Check satfinite (only available for s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << " `satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

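// Assembles the inline-PTX string for this op. As a rough, illustrative
// sketch (not verbatim output of this function), an f32 += f16 * f16 op with
// shape m64n32k16 yields something like:
//   {
//   .reg .pred p;
//   setp.ne.b32 p, $34, 0;
//   wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16
//     {$0, ..., $15}, $32, $33, p, $35, $36, $37, $38;
//   }
// where $0-$15 are the accumulator registers, $32/$33 the matrix descriptors,
// and the trailing operands the scale and transpose immediates.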
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {

  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Map the read/write registers: the accumulator also appears as a tied
  // input (inout), so its registers are counted twice before the descriptor
  // operands.
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing, it should be set between 0 and 15");
  return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()
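
// For example (illustrative), GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, 3, true)
// token-pastes to
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d.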

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
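/// For example (illustrative): a thread-id register op whose `range`
/// attribute spans [0, 1024) gets its result range constrained to exactly
/// that interval.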
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most three dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
1296