//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

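/// Returns true if the given attribute denotes the GPU shared (workgroup)
/// memory address space, either as the NVGPU integer address space or as a
/// gpu::AddressSpaceAttr with value Workgroup.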
bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
    return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isSharedMemoryAddressSpace(memorySpace);
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());

  if (!srcMemref.isLastDimUnitStride())
    return emitError(
        "source memref's innermost dimension must have unit stride");
  if (!dstMemref.isLastDimUnitStride())
    return emitError(
        "destination memref's innermost dimension must have unit stride");
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
    return emitError()
           << "destination memref must have a memory space attribute of "
              "IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
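  // cp.async copies data in 4, 8, or 16 byte chunks per thread, so the
  // requested element count times the element width must add up to one of
  // these sizes.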
  int64_t dstElements = getDstElements().getZExtValue();
  int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
  if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) {
    unsigned dstWidth = dstMemref.getElementTypeBitWidth();
    InFlightDiagnostic diag = emitError();
    diag << "requested copy of " << dstElements << " elements of width "
         << dstWidth << "b; the number of copy elements must be one of ";
    if ((32 / dstWidth) > 0)
      diag << (32 / dstWidth) << ", ";
    if ((64 / dstWidth) > 0)
      diag << (64 / dstWidth) << ", ";
    if ((128 / dstWidth) > 0)
      diag << (128 / dstWidth) << ".";
    return diag;
  }
  if (getBypassL1().has_value()) {
    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
    if (getBypassL1().value() && sizeInBytes != 16) {
      return emitOpError() << "bypassL1 does not satisfy alignment for "
                           << dstMemref << " with destination element "
                           << dstElements
                           << ". Unset bypassL1, or set the "
                              "destination element count to "
                           << req;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//
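/// Builders that infer the result type from the accumulator operand. The first
/// overload leaves tf32 acceleration disabled; the second takes it as a flag.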
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        mmaShape, UnitAttr());
}

void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
                      ::mlir::OperationState &odsState, Value matrixA,
                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
                      bool tf32Enabled) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        odsBuilder.getI64ArrayAttr(mmaShape),
        tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
}

/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
                                     TypedValue<VectorType> matrixA,
                                     TypedValue<VectorType> matrixB,
                                     TypedValue<VectorType> matrixC,
                                     const std::array<int64_t, 3> &mmaShape,
                                     bool tf32Enabled, bool sparse = false) {

  // The verification for mma.sync covering various shapes and data types is
  // based on the fundamental tensor core shape.

  // "Fundamental" tensor core shapes:
  //  - For F32 (TF32), F16, S8, and S4 data
  //    types the fundamental tensor core operation is of shape 8-by-8-by-128b.
  //  - F64 is an exception and is of shape 8-by-8-by-256b.
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = matrixA.getType();
  auto bVector = matrixB.getType();
  auto cVector = matrixC.getType();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // Certain data types are not allowed in sparse mode.
  if (sparse && aType.isF64())
    return op->emitError() << "f64 is not supported for sparse mode";

  if (aType.isF64()) {
    // exception to the 8-by-8-by-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-by-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth; // 128b wide shapeK

    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return op->emitError()
           << "expected an input data type (i4, i8, f16, bf16, tf32, f64) "
              "supported by "
           << op->getName();
  }

  //
  // Basic verification
  //

  if (aShape.size() != 2) {
    return op->emitError() << "matrixA must be a 2-dimensional vector";
  }

  if (bShape.size() != 2) {
    return op->emitError() << "matrixB must be a 2-dimensional vector";
  }

  if (cShape.size() != 2) {
    return op->emitError() << "matrixC must be a 2-dimensional vector";
  }

  auto [m, n, k] = mmaShape;

  // verify warp-wide size for vector a
  int64_t sparseFactor = sparse ? 2 : 1;
  if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
    return op->emitOpError() << "expected " << m * k / sparseFactor
                             << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kWarpSize != k * n)
    return op->emitOpError()
           << "expected " << k * n << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kWarpSize != m * n)
    return op->emitOpError()
           << "expected " << m * n << " warp-wide matrix C elements";

  // verify tf32 tensor cores are enabled for only F32 datatype
  if (tf32Enabled && !(aType.isF32()))
    return op->emitOpError()
           << "expected tf32 tensor cores only for F32 operands";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if ((aShape[0] != mTile * kTile / sparseFactor) ||
      (aShape[1] != numElementA))
    return op->emitOpError() << "expected matrix A to be shaped ("
                             << mTile * kTile / sparseFactor << " x "
                             << numElementA << ")";

  // verify shape of bVector
  if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
    return op->emitOpError() << "expected matrix B to be shaped ("
                             << kTile * nTile << " x " << numElementB << ")";

  // verify shape of cVector
  if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
    return op->emitOpError() << "expected matrix C to be shaped ("
                             << mTile * nTile << " x " << numElementC << ")";

  return success();
}

LogicalResult MmaSyncOp::verify() {
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()));
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSparseSyncOp
//===----------------------------------------------------------------------===//
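/// Builder that infers the result type from the accumulator operand, defaults
/// the sparsity selector to 0, and leaves tf32 acceleration disabled.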
void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
                            ::mlir::OperationState &odsState, Value matrixA,
                            Value matrixB, Value matrixC, Value sparseMetadata,
                            ArrayRef<int64_t> mmaShape) {
  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
        sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
}

LogicalResult MmaSparseSyncOp::verify() {
  unsigned sparsitySelector = getSparsitySelector();
  if (sparsitySelector > 1)
    return emitOpError() << "sparsity selector should be 0 or 1";
  return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
                         getMatrixC(), getMmaShapeAsArray(),
                         getOperation()->hasAttr(getTf32EnabledAttrName()),
                         true);
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = llvm::cast<VectorType>(getRes().getType());

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  //
  // verification
  //

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref))
    return emitError()
           << "nvgpu.ldmatrix srcMemref must have a memory space attribute "
              "of IntegerAttr("
           << NVGPUDialect::kSharedMemoryAddressSpace
           << ") or gpu::AddressSpaceAttr(Workgroup)";
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works only for 32b or narrower "
                          "element types";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (resShape.size() != 2) {
    return emitError() << "result must be a 2-dimensional vector";
  }
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//

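/// Verifies a TMA tensor map descriptor type and, when a memref type is
/// provided, checks that the descriptor and the memref agree on element type,
/// rank, shape, and shared memory address space.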
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
    Operation *op, nvgpu::TensorMapDescriptorType descType,
    std::optional<MemRefType> memrefType = std::nullopt) {
  MemRefType descMemref = descType.getTensor();
  // Limitation: interleaved layouts are not supported yet.
  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
    return op->emitError() << "Interleave options are not supported yet.";

  // The descriptor's memref must live in shared memory.
  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
    return op->emitError() << "the tensor map descriptor has an incorrect "
                              "address space; it must be the shared memory "
                              "address space";
  }
  // Support only static shapes for the time being.
  if (!descMemref.hasStaticShape())
    return op->emitError() << "the tensor map descriptor must be static shaped";

  for (auto dim : descMemref.getShape()) {
    if (dim <= 0 || dim > kMaxTMADimension) {
      return op->emitError() << "the tensor map descriptor must have "
                                "dimensions between 1 and "
                             << kMaxTMADimension << " but it is " << dim;
    }
  }
  if (descMemref.getRank() > 1 &&
      descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) {
    unsigned lastDimensionByte =
        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
    if (lastDimensionByte != kMaxTMALastdimByte)
      return op->emitError() << "the tensor map descriptor must have a last "
                                "dimension of "
                             << kMaxTMALastdimByte << " bytes but it is "
                             << lastDimensionByte << " bytes";
  }

  // No further verification if a memref type is not provided.
  if (!memrefType.has_value())
    return std::nullopt;

  MemRefType dstMemref = memrefType.value();

  // Check the element type.
  if (descMemref.getElementType() != dstMemref.getElementType()) {
    return op->emitError() << "the element type of the tensor map descriptor "
                              "and the memref must be the same";
  }

  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
    return op->emitError() << "the destination memref has an incorrect "
                              "address space; it must be the shared memory "
                              "address space";
  }
  if (!dstMemref.hasStaticShape())
    return op->emitError() << "the destination memref must be static shaped";

  if (dstMemref.getRank() != descMemref.getRank()) {
    return op->emitError() << "the tensor map descriptor and the memref "
                              "must have the same rank";
  }
  if (!descMemref.getShape().equals(dstMemref.getShape())) {
    return op->emitError() << "memref and tensor map shapes mismatch "
                           << descMemref << " != " << dstMemref;
  }

  return std::nullopt;
}

LogicalResult TmaAsyncLoadOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getDst().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_TmaAsyncStoreOp
//===----------------------------------------------------------------------===//

LogicalResult TmaAsyncStoreOp::verify() {
  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
      *this, getTensorMapDescriptor().getType(), getSrc().getType());
  if (error.has_value())
    return error.value();

  if (getCoordinates().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " coordinates are supported.";
  }
  if (getCoordinates().size() !=
      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
    return emitError() << "number of coordinates does not match the rank of "
                          "the tensor map descriptor";
  }

  return success();
}

LogicalResult TmaCreateDescriptorOp::verify() {
  if (getBoxDimensions().size() > kMaxTMATensorDimension) {
    return emitError() << "Maximum " << kMaxTMATensorDimension
                       << " box dimensions are supported.";
  }

  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_WarpgroupGenerateDescriptorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupGenerateDescriptorOp::verify() {
  std::optional<InFlightDiagnostic> error =
      verifyTmaDescriptorWithMemref(*this, getTensorMap().getType());
  if (error.has_value())
    return error.value();

  if (getTensorMap().getType().getSwizzle() !=
      TensorMapSwizzleKind::SWIZZLE_128B) {
    return emitError() << "only "
                       << stringifyTensorMapSwizzleKind(
                              TensorMapSwizzleKind::SWIZZLE_128B)
                       << " is supported for the time being";
  }

  if (getTensorMap().getType().getInterleave() !=
      TensorMapInterleaveKind::INTERLEAVE_NONE) {
    return emitError() << "only "
                       << stringifyTensorMapInterleaveKind(
                              TensorMapInterleaveKind::INTERLEAVE_NONE)
                       << " is supported for the time being";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaOp
//===----------------------------------------------------------------------===//

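/// Returns success if the accumulator type `typeD` together with the operand
/// element types `typeA` and `typeB` forms a wgmma-supported combination.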
LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
  // F32 += F16 * F16
  // F16 += F16 * F16
  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F32 += TF32 * TF32
  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
    return success();
  // s32 += i8 * i8
  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
    return success();
  // s32 += i1 * i1
  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
    return success();
  // F32 += BF16 * BF16
  // F16 += BF16 * BF16
  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
    return success();
  // F16 += f8 * f8
  // F32 += f8 * f8
  if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) &&
      isa<Float8E5M2Type, Float8E4M3FNType>(typeB) &&
      (typeD.isF32() || typeD.isF16()))
    return success();

  return failure();
}

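/// Returns success if `sizeM` is a multiple of the wgmma tile size in M.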
LogicalResult isAllowedSizeM(int sizeM) {
  if (sizeM % kWgmmaSizeM)
    return failure();
  return success();
}

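/// Returns success if `sizeN` is one of the N tile sizes supported by wgmma
/// for the given operand element type `typeA`.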
LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
  SmallVector<int> allowedN = {8,   16,  24,  32,  40,  48,  56,  64,
                               72,  80,  88,  96,  104, 112, 120, 128,
                               136, 144, 152, 160, 168, 176, 184, 192,
                               200, 208, 216, 224, 232, 240, 248, 256};
  SmallVector<int> allowedNshort = {8,   16,  24,  32,  48,  64,
                                    80,  96,  112, 128, 144, 160,
                                    176, 192, 208, 224, 240, 256};
  if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
      isa<Float8E5M2Type, Float8E4M3FNType>(typeA))
    if (llvm::is_contained(allowedN, sizeN))
      return success();

  if (typeA.isInteger(8) || typeA.isInteger(1))
    if (llvm::is_contained(allowedNshort, sizeN))
      return success();
  return failure();
}

LogicalResult WarpgroupMmaOp::verify() {
  if (getTransposeA() && !getTransposeB())
    return emitOpError()
           << "supports only non-transposed A (row major) "
              "and transposed B (column major) for the time being";
  MemRefType matrixA = getDescriptorA().getType().getTensor();
  MemRefType matrixB = getDescriptorB().getType().getTensor();
  VectorType matrixC = getMatrixC().getType().getFragmented();
  VectorType matrixD = getMatrixD().getType().getFragmented();

  if (matrixC != matrixD)
    return emitOpError() << "type of matrix C and matrix D must be the same";

  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
      matrixC.getRank() != 2 || matrixD.getRank() != 2) {
    return emitOpError()
           << "has matrices A, B, C and D; they must all be 2 dimensional";
  }

  if (matrixA.getShape()[1] != matrixB.getShape()[0])
    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
                         << ") != 1st dim matrix-B (" << matrixB.getShape()[0]
                         << ")";
  if (matrixA.getShape()[0] != matrixC.getShape()[0])
    return emitOpError() << "1st dim matrix-A (" << matrixA.getShape()[0]
                         << ") != 1st dim matrix-C (" << matrixC.getShape()[0]
                         << ")";
  if (matrixB.getShape()[1] != matrixC.getShape()[1])
    return emitOpError() << "2nd dim matrix-B (" << matrixB.getShape()[1]
                         << ") != 2nd dim matrix-C (" << matrixC.getShape()[1]
                         << ")";

  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
                                    matrixA.getElementType(),
                                    matrixB.getElementType())))
    return emitOpError() << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType() << " is not supported";
  // Check N
  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
    return emitOpError() << "has input type " << matrixB << " with n = "
                         << matrixB.getDimSize(1)
                         << ", which is not supported";
  }

  // Accept only f32 accumulators or f16/bf16 operands for now.
  if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() &&
      !matrixA.getElementType().isBF16()) {
    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
                         << " += " << matrixA.getElementType() << " * "
                         << matrixB.getElementType()
                         << " is not supported yet";
  }

  return success();
}

LogicalResult WarpgroupMmaStoreOp::verify() {
  MemRefType dstMemrefType = getDstMemref().getType();
  VectorType vtype = getMatrixD().getType().getFragmented();

  // Limitation
  if (!vtype.getElementType().isF32()) {
    return emitOpError()
           << "hit a limitation: only f32 results for the time being";
  }
  if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
      vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
    return emitOpError() << "result shape [" << vtype.getDimSize(0) << "]["
                         << vtype.getDimSize(1)
                         << "] does not match destination memref shape ["
                         << dstMemrefType.getDimSize(0) << "]["
                         << dstMemrefType.getDimSize(1) << "]";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// WarpgroupMmaInitAccumulatorOp
//===----------------------------------------------------------------------===//

LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {

  nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
  int64_t sizeM = accType.getFragmented().getDimSize(0);
  int64_t sizeN = accType.getFragmented().getDimSize(1);
  Type elemType = accType.getFragmented().getElementType();

  if (failed(isAllowedSizeM(sizeM)) ||
      failed(isAllowedSizeN(sizeN, elemType))) {
    return emitOpError() << "has type " << accType.getFragmented()
                         << ". It does not fit into warp-group "
                            "level (wgmma) matrix multiplication instruction "
                            "(or not supported yet)";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// RcpOp
//===----------------------------------------------------------------------===//

LogicalResult RcpOp::verify() {
  RcpRoundingModeAttr rounding = getRoundingAttr();
  bool ftz = getFtz();
  // Currently, only `rcp_approx` with `ftz` is supported.
  if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
    return emitOpError() << "has a limitation: rounding mode " << rounding
                         << " or non-ftz is not supported yet";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"

#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"