//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

using namespace mlir;
using namespace NVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperands();
  if (op->getNumResults() > 0)
    p << " : " << op->getResultTypes();
}

// <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *context = parser.getContext();
  auto int32Ty = IntegerType::get(context, 32);
  auto int1Ty = IntegerType::get(context, 1);

  SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
  Type type;
  return failure(parser.parseOperandList(ops) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.addTypeToList(type, result.types) ||
                 parser.resolveOperands(ops, {int32Ty, int1Ty},
                                        parser.getNameLoc(), result.operands));
}

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimensions");

  // For Im2Col mode, there are two constraints:
  if (isIm2Col) {
    // 1. Tensor must always be at least 3-d.
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  if (getCoordinates().size() > 5)
    return emitError("expects at most 5 coordinates");
  return success();
}

LogicalResult CpAsyncOp::verify() {
  if (getModifier() != LoadCacheModifierKind::CG &&
      getModifier() != LoadCacheModifierKind::CA)
    return emitError("Only CG and CA cache modifiers are supported.");
  if (getSize() != 4 && getSize() != 8 && getSize() != 16)
    return emitError("expected byte size to be either 4, 8 or 16.");
  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
    return emitError("CG cache modifier is only supported for 16-byte copies.");
  return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  size_t numIm2ColOffsets = getIm2colOffsets().size();
  bool isIm2Col = numIm2ColOffsets > 0;
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
                                         numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
                                         getLoc());
}

LogicalResult CvtFloatToTF32Op::verify() {
  using RndMode = NVVM::FPRoundingMode;
  switch (getRnd()) {
  case RndMode::RNA:
    if (getRelu())
      return emitError("Relu not supported with rna rounding mode.");
    break;
  case RndMode::RN:
  case RndMode::RZ:
    if (getSat() != NVVM::SaturationMode::NONE)
      return emitError(
          "Saturation mode not supported with rn/rz rounding modes.");
    break;
  default:
    return emitError(
        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
  }
  return success();
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
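// For example, an f32 element type maps to f32 when the operand is an
// accumulator but to tf32 when it is a multiplicand, and integer element
// types map to s32 only for accumulators.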
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
  auto half2Type =
      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
  if (operandElType.isF64())
    return NVVM::MMATypes::f64;
  if (operandElType.isF16() || operandElType == half2Type)
    return NVVM::MMATypes::f16;
  if (operandElType.isF32() && isAccumulator)
    return NVVM::MMATypes::f32;
  if (operandElType.isF32() && !isAccumulator)
    return NVVM::MMATypes::tf32;
  if (llvm::isa<IntegerType>(operandElType)) {
    if (isAccumulator)
      return NVVM::MMATypes::s32;
    return std::nullopt;
  }

  if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
    if (structType.getBody().empty())
      return std::nullopt;
    return inferOperandMMAType(structType.getBody()[0], isAccumulator);
  }

  return std::nullopt;
}

static bool isInt4PtxType(MMATypes type) {
  return (type == MMATypes::u4 || type == MMATypes::s4);
}

static bool isInt8PtxType(MMATypes type) {
  return (type == MMATypes::u8 || type == MMATypes::s8);
}

static bool isIntegerPtxType(MMATypes type) {
  return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
         type == MMATypes::s32;
}

MMATypes MmaOp::accumPtxType() {
  std::optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
  assert(val.has_value() && "accumulator PTX type should always be inferrable");
  return val.value();
}

MMATypes MmaOp::resultPtxType() {
  std::optional<mlir::NVVM::MMATypes> val =
      inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
  assert(val.has_value() && "result PTX type should always be inferrable");
  return val.value();
}

void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    std::optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : " << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}

void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(),
                                       /*isAccumulator=*/false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                    static_cast<int32_t>(operandB.size()),
                                    static_cast<int32_t>(operandC.size())}));
}

// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//     `->` type($res)
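// For instance, an m16n8k8 f16 MMA might be written as follows (mirroring the
// dialect tests; the exact attribute syntax is illustrative):
//   %d = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
//          {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
//           shape = #nvvm.shape<m = 16, n = 8, k = 8>}
//        : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
//          -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>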
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
  struct OperandFragment {
    std::optional<MMATypes> elemtype;
    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
    SmallVector<Type> regTypes;
  };

  Builder &builder = parser.getBuilder();
  std::array<OperandFragment, 4> frags;

  NamedAttrList namedAttributes;

  // A helper to parse the operand segments.
  auto parseMmaOperand = [&](StringRef operandName,
                             OperandFragment &frag) -> LogicalResult {
    if (parser.parseKeyword(operandName).failed())
      return failure();
    if (parser
            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
            .failed())
      return failure();
    return success();
  };

  // Parse the operand segments.
  if (parseMmaOperand("A", frags[0]).failed())
    return failure();
  if (parseMmaOperand("B", frags[1]).failed())
    return failure();
  if (parseMmaOperand("C", frags[2]).failed())
    return failure();

  if (parser.parseOptionalAttrDict(namedAttributes).failed())
    return failure();

  // Parse the type specification and resolve operands.
  SmallVector<Type, 3> operandTypes;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseLParen()))
    return failure();
  if (failed(parser.parseTypeList(operandTypes)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();
  if (operandTypes.size() != 3)
    return parser.emitError(
        parser.getNameLoc(),
        "expected one type for each operand segment but got " +
            Twine(operandTypes.size()) + " types");
  for (const auto &iter : llvm::enumerate(operandTypes)) {
    auto &frag = frags[iter.index()];
    frag.regTypes.resize(frag.regs.size(), iter.value());
    if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                      parser.getNameLoc(), result.operands)))
      return failure();
    frag.elemtype = inferOperandMMAType(frag.regTypes[0],
                                        /*isAccumulator=*/iter.index() < 2);
  }

  Type resultType;
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();
  frags[3].elemtype = inferOperandMMAType(resultType, /*isAccumulator=*/true);

  std::array<StringRef, 2> names{"multiplicandAPtxType",
                                 "multiplicandBPtxType"};
  for (unsigned idx = 0; idx < names.size(); idx++) {
    const auto &frag = frags[idx];
    std::optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
    if (!frag.elemtype.has_value() && !attr.has_value()) {
      return parser.emitError(
          parser.getNameLoc(),
          "attribute " + names[idx] +
              " is not provided explicitly and cannot be inferred");
    }
    if (!attr.has_value())
      result.addAttribute(
          names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
  }

  result.addTypes(resultType);
  if (!namedAttributes.empty())
    result.addAttributes(namedAttributes);
  result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr({
                          static_cast<int32_t>(frags[0].regs.size()),
                          static_cast<int32_t>(frags[1].regs.size()),
                          static_cast<int32_t>(frags[2].regs.size()),
                      }));
  return success();
}

LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types and shapes for
  // matrices A, B, C, and the result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
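  // For example, with f16 inputs (kFactor = 8) an m16n8k16 MMA expects
  // unitA = (16/8) * (16/8) = 4 A registers of type vector<2xf16> and
  // unitB = (8/8) * (16/8) = 2 B registers.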
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::is_contained(allowedShapes, mmaShape)) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    bool match = llvm::is_contained(iter.value(), operandTySeg);

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorMessage);
    }
  }

  // Check the result type.
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorMessage);
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}

LogicalResult ShflOp::verify() {
  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
    return success();
  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
  auto elementType = (type && type.getBody().size() == 2)
                         ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
                         : nullptr;
  if (!elementType || elementType.getWidth() != 1)
    return emitError("expected return type to be a two-element struct with "
                     "i1 as the second element");
  return success();
}

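/// Infers the element type and element count of a WMMA fragment. For example,
/// an f16 `a`/`b` fragment maps to 8 elements of vector<2xf16>, while an f16
/// accumulator (`c`) fragment maps to 4 such elements.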
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {
  unsigned numberElements = 0;
  Type elementType;
  OpBuilder builder(context);
  Type f16x2 = VectorType::get(2, builder.getF16Type());
  if (type == NVVM::MMATypes::f16) {
    elementType = f16x2;
    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
      numberElements = 8;
    else
      numberElements = 4;
  } else if (type == NVVM::MMATypes::f32) {
    elementType = builder.getF32Type();
    numberElements = 8;
  } else if (type == NVVM::MMATypes::tf32) {
    elementType = builder.getI32Type();
    numberElements = 4;
  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
    elementType = builder.getI32Type();
    int parallelSize = 0;
    if (frag == NVVM::MMAFrag::a)
      parallelSize = nRow;
    if (frag == NVVM::MMAFrag::b)
      parallelSize = nCol;

    // m == 16 && n == 16 && k == 16
    if (parallelSize == 16)
      numberElements = 2;
    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
    else if (parallelSize == 8)
      numberElements = 1;
    else if (parallelSize == 32)
      numberElements = 4;
  } else if (type == NVVM::MMATypes::s32) {
    elementType = builder.getI32Type();
    numberElements = 8;
  }
  assert(numberElements != 0 && elementType != nullptr);
  return std::make_pair(elementType, numberElements);
}

static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {
  int nRow, nCol;
  if (frag == NVVM::MMAFrag::a) {
    nRow = m;
    nCol = k;
  } else if (frag == NVVM::MMAFrag::b) {
    nRow = k;
    nCol = n;
  } else {
    nRow = m;
    nCol = n;
  }
  assert(nRow && nCol);
  return inferMMAType(type, frag, nRow, nCol, context);
}

LogicalResult NVVM::WMMALoadOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                       getEltype(), getFrag()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfo.second << " elements of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAStoreOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != 0 && addressSpace != NVVM::kGlobalMemorySpace &&
      addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected destination pointer in memory "
                       "space 0, 1, 3");

  if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  if (getArgs().size() != typeInfo.second)
    return emitOpError() << "expected " << typeInfo.second << " data operands";
  if (llvm::any_of(getArgs(), [&typeInfo](Value operand) {
        return operand.getType() != typeInfo.first;
      }))
    return emitOpError() << "expected data operands of type " << typeInfo.first;
  return success();
}

LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type to be a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}

LogicalResult NVVM::LdMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  if (getNum() != 1 && getNum() != 2 && getNum() != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  Type i32 = IntegerType::get(getContext(), 32);
  if (getNum() == 1 && getType() != i32)
    return emitOpError("expected destination type to be i32");
  if (getNum() == 2 || getNum() == 4) {
    Type dstType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(getNum(), i32));
    if (getType() != dstType)
      return emitOpError("expected destination type to be a structure of ")
             << getNum() << " elements of type i32";
  }
  return success();
}

LogicalResult NVVM::StMatrixOp::verify() {
  unsigned addressSpace =
      llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
  if (addressSpace != NVVM::kSharedMemorySpace)
    return emitOpError("expected source pointer in memory space 3");

  int numMatrix = getSources().size();
  if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  return success();
}

FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
  if (typeA == NVVM::WGMMATypes::tf32)
    return 8;
  if (typeA == NVVM::WGMMATypes::f16 || typeA == NVVM::WGMMATypes::bf16)
    return 16;
  if (typeA == NVVM::WGMMATypes::s8 || typeA == NVVM::WGMMATypes::u8)
    return 32;
  if (typeA == NVVM::WGMMATypes::e4m3 || typeA == NVVM::WGMMATypes::e5m2)
    return 32;
  if (typeA == NVVM::WGMMATypes::b1)
    return 256;
  return failure();
}

LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {
  switch (typeA) {
  case NVVM::WGMMATypes::f16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::f16)
      return success();
    break;
  case NVVM::WGMMATypes::tf32:
    if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
      return success();
    break;
  case NVVM::WGMMATypes::u8:
  case NVVM::WGMMATypes::s8:
    if (typeD == NVVM::WGMMATypes::s32 &&
        (typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
      return success();
    break;
  case NVVM::WGMMATypes::b1:
    if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
      return success();
    break;
  case NVVM::WGMMATypes::bf16:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        typeB == NVVM::WGMMATypes::bf16)
      return success();
    break;
  case NVVM::WGMMATypes::e4m3:
  case NVVM::WGMMATypes::e5m2:
    if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
        (typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  switch (typeA) {
  case WGMMATypes::f16:
  case WGMMATypes::tf32:
  case WGMMATypes::bf16:
  case WGMMATypes::e4m3:
  case WGMMATypes::e5m2:
    if (llvm::is_contained(allowedN, sizeN))
      return success();
    break;
  case WGMMATypes::u8:
  case WGMMATypes::s8:
  case WGMMATypes::b1:
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
    break;
  case WGMMATypes::f32:
  case WGMMATypes::s32:
    llvm_unreachable("unsupported input types");
    break;
  }
  return failure();
}

LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
  Value outValue = getResults();
  auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
  if (!stype)
    return emitOpError() << "expected results to be struct";
  int outputSize = stype.getBody().size();
  WGMMATypes typeD = getTypeD();
  WGMMATypes typeA = getTypeA();
  WGMMATypes typeB = getTypeB();

  for (Type t : stype.getBody()) {
    if (t != stype.getBody().front())
      return emitOpError()
             << "all elements in struct must be the same type but there is "
             << t;
  }

  if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
      typeD != WGMMATypes::s32) {
    return emitOpError() << "does not support the given output type "
                         << NVVM::stringifyWGMMATypes(typeD);
  }
  if (typeD == WGMMATypes::s32 &&
      (getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
    return emitOpError() << "has s32 output; scaleA and scaleB cannot be neg";
  }

  if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
    return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
                         << " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
                         << NVVM::stringifyWGMMATypes(typeB)
                         << " is not supported";
  }

  // Check M.
  if (getShape().getM() != 64)
    return emitOpError() << "shape 'm' must be 64";

  // Check K.
  FailureOr<int> allowedK = getAllowedSizeK(typeA);
  if (failed(allowedK) || allowedK.value() != getShape().getK())
    return emitOpError() << "shape 'k' must be " << allowedK.value()
                         << " for input type "
                         << NVVM::stringifyWGMMATypes(typeA);

  // Check N.
  if (failed(isAllowedSizeN(getShape().getN(), typeA))) {
    return emitOpError() << "has input type "
                         << NVVM::stringifyWGMMATypes(typeA) << " and n = "
                         << getShape().getN() << ", which is not supported";
  }

  // Check transpose (only available for f16/bf16).
  // Matrix A should be stored row-major and B column-major. Only f16/bf16
  // matrices can be stored in either column-major or row-major by setting
  // the transpose value (imm-trans-a, imm-trans-b) in PTX code.
  if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
      (getLayoutA() == mlir::NVVM::MMALayout::col ||
       getLayoutB() == mlir::NVVM::MMALayout::row)) {
    return emitOpError()
           << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
           << " and layout_b = " << stringifyMMALayout(getLayoutB())
           << " for input types " << stringifyWGMMATypes(typeA) << " and "
           << stringifyWGMMATypes(typeB)
           << " require transpose. However, this is only supported for: "
           << stringifyMMATypes(MMATypes::f16) << " and "
           << stringifyMMATypes(MMATypes::bf16);
  }

  // Check result registers.
  int expectedOutput = 0;
  if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
    expectedOutput = getShape().getN() / 2;
  if (typeD == WGMMATypes::f16)
    expectedOutput = getShape().getN() / 4;
  if (outputSize != expectedOutput) {
    return emitOpError() << "expected " << expectedOutput
                         << " elements in the output struct, however it has "
                         << outputSize << " elements";
  }
  // Check satfinite (only available for s32 accumulator).
  if (typeD != WGMMATypes::s32 &&
      getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
          NVVM::MMAIntOverflow::satfinite) {
    return emitOpError()
           << "`satfinite` can only be used with an s32 accumulator, however "
              "the current accumulator is "
           << NVVM::stringifyWGMMATypes(typeD);
  }

  return success();
}

std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
  int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;

  StringRef outputTypeName = stringifyWGMMATypes(getTypeD());

  int expectedOutputRegisters = 0;
  if (getTypeD() == WGMMATypes::f16)
    expectedOutputRegisters = getShape().getN() / 4;
  else
    expectedOutputRegisters = getShape().getN() / 2;

  std::string ptx;
  llvm::raw_string_ostream ss(ptx);

  ss << "{\n"
        ".reg .pred p;\n"
        "setp.ne.b32 p, $"
     << ((expectedOutputRegisters * 2) + 2)
     << ", 0;\n"
        "wgmma.mma_async.sync.aligned.m"
     << m << "n" << n << "k" << k << "." << outputTypeName << "."
     << stringifyWGMMATypes(getTypeA()) << "."
     << stringifyWGMMATypes(getTypeB());
  if (getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
      NVVM::MMAIntOverflow::satfinite)
    ss << ".satfinite";
  ss << " {";
  int regCnt = 0;
  for (; regCnt < expectedOutputRegisters; ++regCnt) {
    ss << "$" << regCnt;
    if (regCnt != expectedOutputRegisters - 1)
      ss << ", ";
  }

  ss << "},";
  // Need to map read/write registers correctly.
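  // The accumulator registers are read-write, so they are numbered twice in
  // the inline-asm template: once as outputs ($0 .. n-1) and once as tied
  // inputs ($n .. 2n-1). The remaining read-only operands (descriptor A/B,
  // the scale-d predicate source, and the optional scale/transpose
  // immediates) therefore start at $2n, hence the doubling below.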
  regCnt = (regCnt * 2);
  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
  if (getTypeD() != WGMMATypes::s32) {
    ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
  }
  // Don't add transpose parameters unless needed.
  if (isF16) {
    ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
  }
  ss << ";\n"
     << "}\n";
  return ptx;
}

void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {
  bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
  if (getResults())
    asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
  if (getInouts())
    asmValues.push_back({getInouts(), mlir::NVVM::PTXRegisterMod::ReadWrite});
  asmValues.push_back({getDescriptorA(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
  asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
                       mlir::NVVM::PTXRegisterMod::Read});
  if (getTypeD() != WGMMATypes::s32) {
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter,
                         getScaleB() == NVVM::WGMMAScaleIn::neg ? -1 : 1),
         mlir::NVVM::PTXRegisterMod::Read});
  }
  if (isF16) {
    asmValues.push_back(
        {makeConstantI32(rewriter, static_cast<int>(getLayoutA())),
         mlir::NVVM::PTXRegisterMod::Read});
    asmValues.push_back(
        {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
         mlir::NVVM::PTXRegisterMod::Read});
  }
}

LogicalResult NVVM::FenceProxyOp::verify() {
  if (getKind() == NVVM::ProxyKind::TENSORMAP)
    return emitOpError() << "tensormap proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::GENERIC)
    return emitOpError() << "generic proxy is not a supported proxy kind";
  if (getKind() == NVVM::ProxyKind::async_shared && !getSpace().has_value()) {
    return emitOpError() << "async_shared fence requires space attribute";
  }
  if (getKind() != NVVM::ProxyKind::async_shared && getSpace().has_value()) {
    return emitOpError() << "only async_shared fence can have space attribute";
  }
  return success();
}

LogicalResult NVVM::FenceProxyAcquireOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::FenceProxyReleaseOp::verify() {
  if (getFromProxy() != NVVM::ProxyKind::GENERIC)
    return emitOpError("uni-directional proxies only support generic for "
                       "from_proxy attribute");

  if (getToProxy() != NVVM::ProxyKind::TENSORMAP)
    return emitOpError("uni-directional proxies only support tensormap "
                       "for to_proxy attribute");

  return success();
}

LogicalResult NVVM::SetMaxRegisterOp::verify() {
  if (getRegCount() % 8)
    return emitOpError("new register size must be a multiple of 8");
  if (getRegCount() < 24 || getRegCount() > 256)
    return emitOpError("new register size must be between 24 and 256");
  return success();
}

LogicalResult NVVM::BarrierOp::verify() {
  if (getNumberOfThreads() && !getBarrierId())
    return emitOpError(
        "barrier id is missing; it should be set between 0 and 15");
  return success();
}

#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
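// For example, GET_CP_ASYNC_ID(cg, 16, true) selects
// llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16_s, the variant taking an
// explicit cp-size operand.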

llvm::Intrinsic::ID
CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                                 llvm::SmallVector<llvm::Value *> &args) {
  llvm::Intrinsic::ID id;

  auto cpAsyncOp = cast<NVVM::CpAsyncOp>(op);
  bool hasCpSize = static_cast<bool>(cpAsyncOp.getCpSize());
  switch (cpAsyncOp.getSize()) {
  case 4:
    id = GET_CP_ASYNC_ID(ca, 4, hasCpSize);
    break;
  case 8:
    id = GET_CP_ASYNC_ID(ca, 8, hasCpSize);
    break;
  case 16:
    id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG)
             ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
             : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
    break;
  default:
    llvm_unreachable("Invalid copy size in CpAsyncOp.");
  }

  // Fill the intrinsic args.
  args.push_back(mt.lookupValue(cpAsyncOp.getDst()));
  args.push_back(mt.lookupValue(cpAsyncOp.getSrc()));
  if (hasCpSize)
    args.push_back(mt.lookupValue(cpAsyncOp.getCpSize()));

  return id;
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  switch (tensorDims) {
  case 1:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
  case 2:
    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
  case 3:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
  case 4:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
  case 5:
    return isIm2Col
               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
  default:
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
  }
}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                        \
  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)                \
            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)

#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
  [&]() -> auto {                                                              \
    switch (dims) {                                                            \
    case 1:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
    case 2:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                    \
    case 3:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                    \
    case 4:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                    \
    case 5:                                                                    \
      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                    \
    default:                                                                   \
      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
    }                                                                          \
  }()
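// For example, GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, 3, true) yields
// llvm::Intrinsic::nvvm_cp_async_bulk_tensor_reduce_add_im2col_3d; im2col
// variants exist only for 3-5 dimensions, so 1D/2D always use tile mode.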

llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
  using RedTy = NVVM::TMAReduxKind;
  switch (kind) {
  case RedTy::ADD:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
  case RedTy::MIN:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
  case RedTy::MAX:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
  case RedTy::INC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
  case RedTy::DEC:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
  case RedTy::AND:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
  case RedTy::OR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
  case RedTy::XOR:
    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
  }
  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}

llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                     NVVM::SaturationMode sat,
                                                     bool hasRelu) {
  using RndMode = NVVM::FPRoundingMode;
  switch (rnd) {
  case RndMode::RN:
    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
                   : llvm::Intrinsic::nvvm_f2tf32_rn;
  case RndMode::RZ:
    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
                   : llvm::Intrinsic::nvvm_f2tf32_rz;
  case RndMode::RNA:
    return (sat == NVVM::SaturationMode::SATFINITE)
               ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
               : llvm::Intrinsic::nvvm_f2tf32_rna;
  default:
    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
  }
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
                                  ArrayRef<::mlir::ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  if (auto rangeAttr = op->getAttrOfType<LLVM::ConstantRangeAttr>("range")) {
    setResultRanges(result, {rangeAttr.getLower(), rangeAttr.getUpper(),
                             rangeAttr.getLower(), rangeAttr.getUpper()});
  }
}
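// For example, a thread-id register op whose "range" attribute bounds the
// value to [0, blockDim) lets downstream integer range analysis fold away
// bounds checks; both the unsigned and signed ranges are set to the same
// interval here.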

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
  declarePromisedInterface<ConvertToLLVMPatternInterface, NVVMDialect>();
  declarePromisedInterface<gpu::TargetAttrInterface, NVVMTargetAttr>();
}

LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  StringAttr attrName = attr.getName();
  // Kernel function attribute should be attached to functions.
  if (attrName == NVVMDialect::getKernelFuncAttrName()) {
    if (!isa<LLVM::LLVMFuncOp>(op)) {
      return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
                             << "' attribute attached to unexpected op";
    }
  }
  // If maxntid / reqntid / cluster_dim exist, they must be integer arrays
  // with at most 3 dimensions.
  if (attrName == NVVMDialect::getMaxntidAttrName() ||
      attrName == NVVMDialect::getReqntidAttrName() ||
      attrName == NVVMDialect::getClusterDimAttrName()) {
    auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
    if (!values || values.empty() || values.size() > 3)
      return op->emitError()
             << "'" << attrName
             << "' attribute must be an integer array with at most 3 elements";
  }
  // If minctasm / maxnreg / cluster_max_blocks exist, they must be integer
  // attributes.
  if (attrName == NVVMDialect::getMinctasmAttrName() ||
      attrName == NVVMDialect::getMaxnregAttrName() ||
      attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
    if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
      return op->emitError()
             << "'" << attrName << "' attribute must be an integer constant";
  }

  return success();
}

LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
  StringAttr attrName = argAttr.getName();
  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
    if (!isKernel) {
      return op->emitError()
             << "'" << attrName
             << "' attribute must be present only on kernel arguments";
    }
    if (!isa<UnitAttr>(argAttr.getValue()))
      return op->emitError() << "'" << attrName << "' must be a unit attribute";
    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName())) {
      return op->emitError()
             << "'" << attrName
             << "' attribute requires the argument to also have attribute '"
             << LLVM::LLVMDialect::getByValAttrName() << "'";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {
  if (optLevel < 0 || optLevel > 3) {
    emitError() << "The optimization level must be a number between 0 and 3.";
    return failure();
  }
  if (triple.empty()) {
    emitError() << "The target triple cannot be empty.";
    return failure();
  }
  if (chip.empty()) {
    emitError() << "The target chip cannot be empty.";
    return failure();
  }
  if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
        return attr && mlir::isa<StringAttr>(attr);
      })) {
    emitError() << "All the elements in the `link` array must be strings.";
    return failure();
  }
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"