1 //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the NVGPU dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 14 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 15 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/DialectImplementation.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/OpImplementation.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/IR/Verifier.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/TypeSwitch.h" 29 30 using namespace mlir; 31 using namespace mlir::nvgpu; 32 33 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" 34 35 void nvgpu::NVGPUDialect::initialize() { 36 addTypes< 37 #define GET_TYPEDEF_LIST 38 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc" 39 >(); 40 addAttributes< 41 #define GET_ATTRDEF_LIST 42 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" 43 >(); 44 addOperations< 45 #define GET_OP_LIST 46 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" 47 >(); 48 } 49 50 bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { 51 if (!memorySpace) 52 return false; 53 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace)) 54 return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace; 55 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) 56 return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; 57 return false; 58 } 59 60 bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { 61 Attribute memorySpace = type.getMemorySpace(); 62 return isSharedMemoryAddressSpace(memorySpace); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // NVGPU_DeviceAsyncCopyOp 67 //===----------------------------------------------------------------------===// 68 69 LogicalResult DeviceAsyncCopyOp::verify() { 70 auto srcMemref = llvm::cast<MemRefType>(getSrc().getType()); 71 auto dstMemref = llvm::cast<MemRefType>(getDst().getType()); 72 73 if (!srcMemref.isLastDimUnitStride()) 74 return emitError("source memref most minor dim must have unit stride"); 75 if (!dstMemref.isLastDimUnitStride()) 76 return emitError("destination memref most minor dim must have unit stride"); 77 if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) 78 return emitError() 79 << "destination memref must have a memory space attribute of " 80 "IntegerAttr(" 81 << NVGPUDialect::kSharedMemoryAddressSpace 82 << ") or gpu::AddressSpaceAttr(Workgroup)"; 83 if (dstMemref.getElementType() != srcMemref.getElementType()) 84 return emitError("source and destination must have the same element type"); 85 if (size_t(srcMemref.getRank()) != getSrcIndices().size()) 86 return emitOpError() << "expected " << srcMemref.getRank() 87 << " source indices, got " << getSrcIndices().size(); 88 if (size_t(dstMemref.getRank()) != getDstIndices().size()) 89 return emitOpError() << "expected " << dstMemref.getRank() 90 << " destination indices, got " 91 << getDstIndices().size(); 92 int64_t dstElements = getDstElements().getZExtValue(); 93 int64_t sizeInBytes = (dstMemref.getElementTypeBitWidth() * dstElements) / 8; 94 if (sizeInBytes != 4 && sizeInBytes != 8 && sizeInBytes != 16) { 95 unsigned dstWidth = dstMemref.getElementTypeBitWidth(); 96 InFlightDiagnostic diag = emitError(); 97 diag << "Requested copy elements is " << dstElements << " with width " 98 << dstMemref.getElementTypeBitWidth() 99 << ". But copy elements could be one of "; 100 if ((32 / dstWidth) > 0) 101 diag << (32 / dstWidth) << ", "; 102 if ((64 / dstWidth) > 0) 103 diag << (64 / dstWidth) << ", "; 104 if ((128 / dstWidth) > 0) 105 diag << (128 / dstWidth) << "."; 106 return diag; 107 } 108 if (getBypassL1().has_value()) { 109 int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth(); 110 if (getBypassL1().value() && sizeInBytes != 16) { 111 return emitOpError() << "bypassL1 does not satify alignment for " 112 << dstMemref << " with destination element " 113 << dstElements 114 << ". Unset bypassL1, or set " 115 "destination element to " 116 << req; 117 } 118 } 119 return success(); 120 } 121 122 //===----------------------------------------------------------------------===// 123 // NVGPU_MmaSyncOp 124 //===----------------------------------------------------------------------===// 125 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, 126 ::mlir::OperationState &odsState, Value matrixA, 127 Value matrixB, Value matrixC, ArrayAttr mmaShape) { 128 build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, 129 mmaShape, UnitAttr()); 130 } 131 132 void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder, 133 ::mlir::OperationState &odsState, Value matrixA, 134 Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape, 135 bool tf32Enabled) { 136 build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, 137 odsBuilder.getI64ArrayAttr(mmaShape), 138 tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr()); 139 } 140 141 /// Performs verification for MmaSyncOp and MmaSparseSyncOp. 142 static LogicalResult verifyMmaSyncOp(Operation *op, 143 TypedValue<VectorType> matrixA, 144 TypedValue<VectorType> matrixB, 145 TypedValue<VectorType> matrixC, 146 const std::array<int64_t, 3> &mmaShape, 147 bool tf32Enabled, bool sparse = false) { 148 149 // The verification for mma.sync covering various shapes and data types is 150 // based on the fundamental tensor core shape. 151 152 // "Fundamental" tensor core shapes: 153 // - For F32 (TF32), F16, S8, and S4 data 154 // types the fundamental tensor core operation is of shape 8-by-8-by-128b. 155 // - F64 is an exception and is of shape 8-by-8-by-256b. 156 int64_t shapeM = 8; 157 int64_t shapeN = 8; 158 int64_t shapeK; // set based on data type (128b for all data types except F64) 159 160 // Number of elements A, B, and C per thread per fundamental tensor core tile 161 int64_t numElementA; // set based on data type (32b except F64) 162 int64_t numElementB; // set based on data type (32b except F64) 163 int64_t numElementC{2}; // two accumulator elements per fundamental tile 164 165 // nvgpu.mma.sync vector operands (per thread) 166 auto aVector = matrixA.getType(); 167 auto bVector = matrixB.getType(); 168 auto cVector = matrixC.getType(); 169 170 // vector shapes 171 ArrayRef<int64_t> aShape = aVector.getShape(); 172 ArrayRef<int64_t> bShape = bVector.getShape(); 173 ArrayRef<int64_t> cShape = cVector.getShape(); 174 175 // vector element type 176 Type aType = aVector.getElementType(); 177 178 // Certain data types are not allowed in sparse mode. 179 if (sparse && aType.isF64()) 180 return op->emitError() << "f64 is not supported for sparse mode"; 181 182 if (aType.isF64()) { 183 // exception to 8-by-8-128b fundamental tensor core tile size 184 shapeK = 4; 185 numElementA = 1; 186 numElementB = 1; 187 } else if (aType.isF32() || aType.isBF16() || aType.isF16() || 188 aType.isInteger(8) || aType.isInteger(4)) { 189 // 8-by-8-128b fundamental tensor core tile size 190 int operandBitwidth = aType.getIntOrFloatBitWidth(); 191 shapeK = 128 / operandBitwidth; // 128b wide shapeK 192 193 numElementA = 32 / operandBitwidth; // 32b wide operand A 194 numElementB = 32 / operandBitwidth; // 32b wide operand B 195 } else { 196 return op->emitError() 197 << "expected input data type (i4,i8,f16,bf16,tf32,f64) " 198 "supported by " 199 << op->getName(); 200 } 201 202 // 203 // Basic verification 204 // 205 206 if (aShape.size() != 2) { 207 return op->emitError() << "matrixA must be 2 dimensional vector"; 208 } 209 210 if (bShape.size() != 2) { 211 return op->emitError() << "matrixB must be 2 dimensional vector"; 212 } 213 214 if (cShape.size() != 2) { 215 return op->emitError() << "matrixC must be 2 dimensional vector"; 216 } 217 218 auto [m, n, k] = mmaShape; 219 220 // verify warp-wide size for vector a 221 int64_t sparseFactor = sparse ? 2 : 1; 222 if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) 223 return op->emitOpError() 224 << "expected " << m * k << " warp-wide matrix A elements"; 225 226 // verify warp-wide size for vector b 227 if (bShape[0] * bShape[1] * kWarpSize != k * n) 228 return op->emitOpError() 229 << "expected " << k * n << " warp-wide matrix B elements"; 230 231 // verify warp-wide size for vector c 232 if (cShape[0] * cShape[1] * kWarpSize != m * n) 233 return op->emitOpError() 234 << "expected " << m * n << " warp-wide matrix C elements"; 235 236 // verify tf32 tensor cores are enabled for only F32 datatype 237 if (tf32Enabled && !(aType.isF32())) 238 return op->emitOpError() 239 << "expected tf32 tensor cores only for F32 operands"; 240 241 // 242 // Extended verification 243 // 244 245 // tiles of fundamental tensor core operations 246 int64_t mTile = m / shapeM; 247 int64_t nTile = n / shapeN; 248 int64_t kTile = k / shapeK; 249 250 // verify shape of aVector 251 if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) || 252 (aShape[1] != numElementA)) 253 return op->emitOpError() << "expected matrix A to be shaped (" 254 << mTile * kTile << " x " << numElementA << ")"; 255 256 // verify shape of bVector 257 if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB)) 258 return op->emitOpError() << "expected matrix B to be shaped (" 259 << kTile * nTile << " x " << numElementB << ")"; 260 261 // verify shape of cVector 262 if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC)) 263 return op->emitOpError() << "expected matrix C to be shaped (" 264 << mTile * nTile << " x " << numElementC << ")"; 265 266 return success(); 267 } 268 269 LogicalResult MmaSyncOp::verify() { 270 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), 271 getMatrixC(), getMmaShapeAsArray(), 272 getOperation()->hasAttr(getTf32EnabledAttrName())); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // NVGPU_MmaSparseSyncOp 277 //===----------------------------------------------------------------------===// 278 void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder, 279 ::mlir::OperationState &odsState, Value matrixA, 280 Value matrixB, Value matrixC, Value sparseMetadata, 281 ArrayRef<int64_t> mmaShape) { 282 build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC, 283 sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr()); 284 } 285 286 LogicalResult MmaSparseSyncOp::verify() { 287 unsigned sparsitySelector = getSparsitySelector(); 288 if (sparsitySelector > 1) 289 return emitOpError() << "sparsity selector should be 0 or 1"; 290 return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), 291 getMatrixC(), getMmaShapeAsArray(), 292 getOperation()->hasAttr(getTf32EnabledAttrName()), 293 true); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // NVGPU_LdMatrixOp 298 //===----------------------------------------------------------------------===// 299 LogicalResult LdMatrixOp::verify() { 300 301 // ldmatrix reads data from source in shared memory 302 auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType()); 303 304 // ldmatrix writes data to result/destination in vector registers 305 auto resVector = llvm::cast<VectorType>(getRes().getType()); 306 307 // vector register shape, element type, and bitwidth 308 ArrayRef<int64_t> resShape = resVector.getShape(); 309 Type resType = resVector.getElementType(); 310 int64_t elementBitWidth = resType.getIntOrFloatBitWidth(); 311 312 // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread 313 int64_t numElementsPer32b = 32 / elementBitWidth; 314 315 // number of 8-by-8 tiles 316 int64_t numTiles = getNumTiles(); 317 318 // transpose elements in vector registers at 16b granularity when true 319 bool isTranspose = getTranspose(); 320 321 // 322 // verification 323 // 324 325 if (!NVGPUDialect::hasSharedMemoryAddressSpace(srcMemref)) 326 return emitError() 327 << "expected nvgpu.ldmatrix srcMemref must have a memory space " 328 "attribute of IntegerAttr(" 329 << NVGPUDialect::kSharedMemoryAddressSpace 330 << ") or gpu::AddressSpaceAttr(Workgroup)"; 331 if (elementBitWidth > 32) 332 return emitError() << "nvgpu.ldmatrix works for 32b or lower"; 333 if (isTranspose && !(elementBitWidth == 16)) 334 return emitError() 335 << "nvgpu.ldmatrix transpose works only at 16b granularity"; 336 if (resShape.size() != 2) { 337 return emitError() << "results must be 2 dimensional vector"; 338 } 339 if (!(resShape[1] == numElementsPer32b)) 340 return emitError() << "expected vector register shape[1] = " 341 << numElementsPer32b; 342 if (!(resShape[0] == numTiles)) 343 return emitError() 344 << "expected vector register shape[0] and numTiles to match"; 345 346 return success(); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // NVGPU_TmaAsyncLoadOp 351 //===----------------------------------------------------------------------===// 352 353 std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( 354 Operation *op, nvgpu::TensorMapDescriptorType descType, 355 std::optional<MemRefType> memrefType = std::nullopt) { 356 MemRefType descMemref = descType.getTensor(); 357 // Limitation 358 if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE) 359 return op->emitError() << "Interleave options are not supported yet."; 360 361 // Address space check for shared memory check 362 if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) { 363 return op->emitError() << "the tensor map descriptor has incorrect address " 364 "space, it must be shared memory address space."; 365 } 366 // Support only static shape for the time being 367 if (!descMemref.hasStaticShape()) 368 return op->emitError() << "the tensor map descriptor must be static shaped"; 369 370 for (auto dim : descMemref.getShape()) { 371 if (dim <= 0 || dim > kMaxTMADimension) { 372 return op->emitError() << "the tensor map descriptor must have " 373 "dimensions between 1 and " 374 << kMaxTMADimension << " but it is " << dim; 375 } 376 } 377 if (descMemref.getRank() > 1 && 378 descType.getSwizzle() != TensorMapSwizzleKind::SWIZZLE_NONE) { 379 unsigned lastDimensionByte = 380 descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8; 381 if (lastDimensionByte != kMaxTMALastdimByte) 382 return op->emitError() << "the tensormap descriptor must have last " 383 "dimension of " 384 << kMaxTMALastdimByte << " bytes but it is " 385 << lastDimensionByte << " bytes"; 386 } 387 388 // No verification if memref type is not provided 389 if (!memrefType.has_value()) 390 return std::nullopt; 391 392 MemRefType dstMemref = memrefType.value(); 393 394 // Check element type 395 if (descMemref.getElementType() != dstMemref.getElementType()) { 396 return op->emitError() << "the element type of tensor map descriptor and " 397 "memref must be same"; 398 } 399 400 if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) { 401 return op->emitError() << "the destination memref has incorrect address " 402 "space, it must be shared memory address space."; 403 } 404 if (!dstMemref.hasStaticShape()) 405 return op->emitError() << "the destination memref must be static shaped"; 406 407 if (dstMemref.getRank() != descMemref.getRank()) { 408 return op->emitError() << "the shape of tensor map descriptor and " 409 "memref must have same rank"; 410 } 411 if (!descMemref.getShape().equals(dstMemref.getShape())) { 412 return op->emitError() << "memref and tensor map shapes mismatch " 413 << descMemref << " != " << dstMemref; 414 } 415 416 return std::nullopt; 417 } 418 419 LogicalResult TmaAsyncLoadOp::verify() { 420 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( 421 *this, getTensorMapDescriptor().getType(), getDst().getType()); 422 if (error.has_value()) 423 return error.value(); 424 425 if (getCoordinates().size() > kMaxTMATensorDimension) { 426 return emitError() << "Maximum " << kMaxTMATensorDimension 427 << " coordinates are supported."; 428 } 429 if (getCoordinates().size() != 430 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) { 431 return emitError() << "number of coordinates do not match with the rank of " 432 "tensor descriptor map."; 433 } 434 435 return success(); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // NVGPU_TmaAsyncStoreOp 440 //===----------------------------------------------------------------------===// 441 442 LogicalResult TmaAsyncStoreOp::verify() { 443 std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref( 444 *this, getTensorMapDescriptor().getType(), getSrc().getType()); 445 if (error.has_value()) 446 return error.value(); 447 448 if (getCoordinates().size() > kMaxTMATensorDimension) { 449 return emitError() << "Maximum " << kMaxTMATensorDimension 450 << " coordinates are supported."; 451 } 452 if (getCoordinates().size() != 453 size_t(getTensorMapDescriptor().getType().getTensor().getRank())) { 454 return emitError() << "number of coordinates do not match with the rank of " 455 "tensor descriptor map."; 456 } 457 458 return success(); 459 } 460 461 LogicalResult TmaCreateDescriptorOp::verify() { 462 if (getBoxDimensions().size() > kMaxTMATensorDimension) { 463 return emitError() << "Maximum " << kMaxTMATensorDimension 464 << " coordinates are supported."; 465 } 466 467 std::optional<InFlightDiagnostic> error = 468 verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); 469 if (error.has_value()) 470 return error.value(); 471 472 return success(); 473 } 474 475 //===----------------------------------------------------------------------===// 476 // NVGPU_WarpgroupGenerateDescriptorOp 477 //===----------------------------------------------------------------------===// 478 479 LogicalResult WarpgroupGenerateDescriptorOp::verify() { 480 std::optional<InFlightDiagnostic> error = 481 verifyTmaDescriptorWithMemref(*this, getTensorMap().getType()); 482 if (error.has_value()) 483 return error.value(); 484 485 if (getTensorMap().getType().getSwizzle() != 486 TensorMapSwizzleKind::SWIZZLE_128B) { 487 return emitError() << "supports only " 488 << stringifyTensorMapSwizzleKind( 489 TensorMapSwizzleKind::SWIZZLE_128B) 490 << " is supported for the time being"; 491 } 492 493 if (getTensorMap().getType().getInterleave() != 494 TensorMapInterleaveKind::INTERLEAVE_NONE) { 495 return emitError() << "supports only " 496 << stringifyTensorMapInterleaveKind( 497 TensorMapInterleaveKind::INTERLEAVE_NONE) 498 << " is supported for the time being"; 499 } 500 501 return success(); 502 } 503 504 //===----------------------------------------------------------------------===// 505 // WarpgroupMmaOp 506 //===----------------------------------------------------------------------===// 507 508 LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { 509 // F32 += F16 + F16 510 // F16 += F16 + F16 511 if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) 512 return success(); 513 // F32 += TF32 + TF32 514 if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) 515 return success(); 516 // s32 += i8 + i8 517 if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32)) 518 return success(); 519 // s32 += i1 + i1 520 if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32)) 521 return success(); 522 // F32 += BF16 + BF16 523 // F16 += BF16 + BF16 524 if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) 525 return success(); 526 // F16 += f8 + f8 527 // F32 += f8 + f8 528 if (isa<Float8E5M2Type, Float8E4M3FNType>(typeA) && 529 isa<Float8E5M2Type, Float8E4M3FNType>(typeB) && 530 (typeD.isF32() || typeD.isF16())) 531 return success(); 532 533 return failure(); 534 } 535 536 LogicalResult isAllowedSizeM(int sizeM) { 537 if (sizeM % kWgmmaSizeM) 538 return failure(); 539 return success(); 540 } 541 542 LogicalResult isAllowedSizeN(int sizeN, Type typeA) { 543 SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, 544 72, 80, 88, 96, 104, 112, 120, 128, 545 136, 144, 152, 160, 168, 176, 184, 192, 546 200, 208, 216, 224, 232, 240, 248, 256}; 547 SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, 548 80, 96, 112, 128, 144, 160, 549 176, 192, 208, 224, 240, 256}; 550 if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() || 551 isa<Float8E5M2Type, Float8E4M3FNType>(typeA)) 552 if (llvm::is_contained(allowedN, sizeN)) 553 return success(); 554 555 if (typeA.isInteger(8) || typeA.isInteger(1)) 556 if (llvm::is_contained(allowedNshort, sizeN)) 557 return success(); 558 return failure(); 559 } 560 561 LogicalResult WarpgroupMmaOp::verify() { 562 if (getTransposeA() && !getTransposeB()) 563 return emitOpError() 564 << "supports non-transpose A (Row Major) " 565 "and transpose B (Column Major) for the time being "; 566 MemRefType matrixA = getDescriptorA().getType().getTensor(); 567 MemRefType matrixB = getDescriptorB().getType().getTensor(); 568 VectorType matrixC = getMatrixC().getType().getFragmented(); 569 VectorType matrixD = getMatrixD().getType().getFragmented(); 570 571 if (matrixC != matrixD) 572 return emitOpError() << "type of matrix C and matrix D must be the same"; 573 574 if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || 575 matrixC.getRank() != 2 || matrixD.getRank() != 2) { 576 return emitOpError() 577 << "has matrices A, B, C and D, they must be 2 dimensional"; 578 } 579 580 if (matrixA.getShape()[1] != matrixB.getShape()[0]) 581 return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1] 582 << ")!= 1st dim matrix-B (" << matrixB.getShape()[0] 583 << " )"; 584 if (matrixA.getShape()[0] != matrixC.getShape()[0]) 585 return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0] 586 << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0] 587 << " )"; 588 if (matrixB.getShape()[1] != matrixC.getShape()[1]) 589 return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1] 590 << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1] 591 << " )"; 592 593 if (failed(isAllowedWGMMADataType(matrixC.getElementType(), 594 matrixA.getElementType(), 595 matrixB.getElementType()))) 596 return emitOpError() << matrixC.getElementType() 597 << " += " << matrixA.getElementType() << " * " 598 << matrixB.getElementType() 599 << ", it is not supported."; 600 // Check N 601 if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { 602 return emitOpError() << "has input type " << matrixB << " n is set to " 603 << matrixB.getDimSize(1) << ", it is not supported"; 604 } 605 606 // Currently, f16/bf16 supported 607 if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && 608 !matrixA.getElementType().isBF16()) { 609 return emitOpError() << "hit a limitation: " << matrixC.getElementType() 610 << " += " << matrixA.getElementType() << " * " 611 << matrixB.getElementType() 612 << ", it is not supported yet"; 613 } 614 615 return success(); 616 } 617 618 LogicalResult WarpgroupMmaStoreOp::verify() { 619 MemRefType dstMemrefType = getDstMemref().getType(); 620 VectorType vtype = getMatrixD().getType().getFragmented(); 621 622 // Limitation 623 if (!vtype.getElementType().isF32()) { 624 return emitOpError() 625 << "hit a limitation: only f32 results for the time being"; 626 } 627 if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) || 628 vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) { 629 return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1) 630 << "] values. However, destination memref[" 631 << dstMemrefType.getDimSize(0) << "][" 632 << dstMemrefType.getDimSize(1) 633 << "] does not have same size as results"; 634 } 635 return success(); 636 } 637 638 //===----------------------------------------------------------------------===// 639 // WarpgroupMmaInitAccumulatorOp 640 //===----------------------------------------------------------------------===// 641 642 LogicalResult WarpgroupMmaInitAccumulatorOp::verify() { 643 644 nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType(); 645 int64_t sizeM = accType.getFragmented().getDimSize(0); 646 int64_t sizeN = accType.getFragmented().getDimSize(1); 647 Type elemType = accType.getFragmented().getElementType(); 648 649 if (failed(isAllowedSizeM(sizeM)) || 650 failed(isAllowedSizeN(sizeN, elemType))) { 651 return emitOpError() << "has type " << accType.getFragmented() 652 << ". It does not fit into warp-group " 653 "level (wgmma) matrix multiplication instruction " 654 "(or not supported yet)"; 655 } 656 return success(); 657 } 658 659 //===----------------------------------------------------------------------===// 660 // RcpOp 661 //===----------------------------------------------------------------------===// 662 663 LogicalResult RcpOp::verify() { 664 RcpRoundingModeAttr rounding = getRoundingAttr(); 665 bool ftz = getFtz(); 666 // Currently, only `rcp_approx` and `ftz` is supported. 667 if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) { 668 return emitOpError() << "has a limitation. " << rounding 669 << " or non-ftz is not supported yet."; 670 } 671 return success(); 672 } 673 674 //===----------------------------------------------------------------------===// 675 // TableGen'd dialect, type, and op definitions 676 //===----------------------------------------------------------------------===// 677 678 #define GET_ATTRDEF_CLASSES 679 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc" 680 681 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc" 682 683 #define GET_OP_CLASSES 684 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" 685 686 #define GET_TYPEDEF_CLASSES 687 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc" 688