//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Number of bits that need to be excluded when building the matrix descriptor
/// for wgmma operations.
constexpr int exclude4LSB = 4;

/// GPUs have 32-bit registers; this function truncates values when a larger
/// width is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}

/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
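/// As an illustrative sketch (not an exhaustive mapping): an
/// `!llvm.array<2 x vector<2xf16>>` result maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while an
/// `!llvm.array<2 x vector<2xf32>>` result maps to a struct of four `f32`
/// scalars, matching the cases handled below.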
61 static Type inferIntrinsicResultType(Type vectorResultType) { 62 MLIRContext *ctx = vectorResultType.getContext(); 63 auto a = cast<LLVM::LLVMArrayType>(vectorResultType); 64 auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); 65 auto i32Ty = IntegerType::get(ctx, 32); 66 auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 67 Type f64Ty = Float64Type::get(ctx); 68 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 69 Type f32Ty = Float32Type::get(ctx); 70 Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 71 if (a.getElementType() == f16x2Ty) { 72 return LLVM::LLVMStructType::getLiteral( 73 ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty)); 74 } 75 if (a.getElementType() == i32x2Ty) { 76 return LLVM::LLVMStructType::getLiteral( 77 ctx, 78 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty)); 79 } 80 if (a.getElementType() == f64x2Ty) { 81 return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); 82 } 83 if (a.getElementType() == f32x2Ty) { 84 return LLVM::LLVMStructType::getLiteral( 85 ctx, 86 SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty)); 87 } 88 if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) { 89 return LLVM::LLVMStructType::getLiteral( 90 ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty)); 91 } 92 return vectorResultType; 93 } 94 95 /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is 96 /// always an LLVM struct) into a fragment that is compatible with the vector 97 /// type of this operation. This involves extracting elements from the struct 98 /// and inserting them into an LLVM array. These extra data-movement 99 /// operations should be canonicalized away by the LLVM backend. 100 static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, 101 Type resultType, Value intrinsicResult, 102 RewriterBase &rewriter) { 103 MLIRContext *ctx = rewriter.getContext(); 104 auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType); 105 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType); 106 Type i32Ty = rewriter.getI32Type(); 107 Type f32Ty = rewriter.getF32Type(); 108 Type f64Ty = rewriter.getF64Type(); 109 Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); 110 Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); 111 Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); 112 Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2); 113 Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1); 114 115 auto makeConst = [&](int32_t index) -> Value { 116 return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32), 117 rewriter.getI32IntegerAttr(index)); 118 }; 119 120 if (arrayType) { 121 SmallVector<Value, 4> elements; 122 123 // The intrinsic returns 32-bit wide elements in a form which can be 124 // directly bitcasted and inserted into the result vector. 125 if (arrayType.getElementType() == f16x2Ty || 126 arrayType.getElementType() == f32x1Ty) { 127 for (unsigned i = 0; i < structType.getBody().size(); i++) { 128 Value el = 129 rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i); 130 el = rewriter.createOrFold<LLVM::BitcastOp>( 131 loc, arrayType.getElementType(), el); 132 elements.push_back(el); 133 } 134 } 135 136 // The intrinsic returns i32, f64, and f32 values as individual scalars, 137 // even when the result is notionally a 64-bit wide element (e.g. f32x2). We 138 // need to extract them from the struct and pack them into the 64-bit wide 139 // rows of the vector result. 
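    // As an illustrative case (assuming an f32x2 fragment): for an
    // `!llvm.array<2 x vector<2xf32>>` result the intrinsic struct holds four
    // f32 scalars, and each consecutive pair is packed into one two-element
    // vector below.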
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}

/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}

/// Returns whether the mbarrier object has a shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
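/// For barriers placed in shared memory this is an i64 IntegerAttr holding
/// `kSharedMemoryAddressSpace`; otherwise a null attribute (generic space) is
/// returned.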
229 Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context, 230 nvgpu::MBarrierGroupType barrierType) { 231 Attribute memorySpace = {}; 232 if (isMbarrierShared(barrierType)) { 233 memorySpace = 234 IntegerAttr::get(IntegerType::get(context, 64), 235 nvgpu::NVGPUDialect::kSharedMemoryAddressSpace); 236 } 237 return memorySpace; 238 } 239 240 /// Returns memref type of the mbarrier object. The type is defined in the 241 /// MBarrierGroupType. 242 MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context, 243 nvgpu::MBarrierGroupType barrierType) { 244 Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType); 245 MemRefLayoutAttrInterface layout; 246 return MemRefType::get({barrierType.getNumBarriers()}, 247 IntegerType::get(context, 64), layout, memorySpace); 248 } 249 250 namespace { 251 252 struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { 253 using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern; 254 255 LogicalResult 256 matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, 257 ConversionPatternRewriter &rewriter) const override { 258 MLIRContext *ctx = getContext(); 259 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 260 261 // The result type of ldmatrix will always be a struct of 32bit integer 262 // registers if more than one 32bit value is returned. Otherwise, the result 263 // is a single i32. The result type of the GPU operation is always a vector 264 // of shape (NumRegisters, VectorRegister) where VectorRegister is the 265 // vector type of the result and always 32 bits long. We bitcast the result 266 // of the NVVM::LdMatrix to this vector type. 267 auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]); 268 if (!vectorResultType) { 269 return failure(); 270 } 271 Type innerVectorType = LLVM::getFixedVectorType( 272 vectorResultType.getElementType(), vectorResultType.getDimSize(1)); 273 274 int64_t num32BitRegs = vectorResultType.getDimSize(0); 275 276 Type ldMatrixResultType; 277 if (num32BitRegs > 1) { 278 ldMatrixResultType = LLVM::LLVMStructType::getLiteral( 279 ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type())); 280 } else { 281 ldMatrixResultType = rewriter.getI32Type(); 282 } 283 284 auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType()); 285 Value srcPtr = 286 getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), 287 adaptor.getIndices(), rewriter); 288 Value ldMatrixResult = b.create<NVVM::LdMatrixOp>( 289 ldMatrixResultType, srcPtr, 290 /*num=*/op.getNumTiles(), 291 /*layout=*/op.getTranspose() ? NVVM::MMALayout::col 292 : NVVM::MMALayout::row); 293 294 // The ldmatrix operation returns either a single i32 value or a struct of 295 // i32 values. Here we unpack those values and cast them back to their 296 // actual vector type (still of width 32b) and repack them into a result 297 // struct. 298 Type finalResultType = typeConverter->convertType(vectorResultType); 299 Value result = b.create<LLVM::UndefOp>(finalResultType); 300 for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { 301 Value i32Register = 302 num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i) 303 : ldMatrixResult; 304 Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register); 305 result = b.create<LLVM::InsertValueOp>(result, casted, i); 306 } 307 308 rewriter.replaceOp(op, result); 309 return success(); 310 } 311 }; 312 313 /// Convert the given type into the corresponding PTX type (NVVM::MMATypes 314 /// enum). 
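/// For example, an i8 element type maps to `s8`, f16 to `f16`, and f32 to
/// `tf32` (the tensor-core representation used for f32 operands), as
/// implemented below.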
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}

struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 work only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
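    // Current default behavior: integer operands select saturating
    // `satfinite` accumulation below, while floating-point operands leave the
    // overflow mode unset.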
364 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); 365 if (isa<IntegerType>(aType.getElementType())) 366 overflow = NVVM::MMAIntOverflow::satfinite; 367 368 SmallVector<Value> matA = 369 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); 370 SmallVector<Value> matB = 371 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); 372 SmallVector<Value> matC = 373 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); 374 375 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 376 Type intrinsicResTy = inferIntrinsicResultType( 377 typeConverter->convertType(op->getResultTypes()[0])); 378 Value intrinsicResult = b.create<NVVM::MmaOp>( 379 intrinsicResTy, matA, matB, matC, 380 /*shape=*/gemmShape, 381 /*b1Op=*/std::nullopt, 382 /*intOverflow=*/overflow, 383 /*multiplicandPtxTypes=*/ 384 std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB}, 385 /*multiplicandLayouts=*/ 386 std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row, 387 NVVM::MMALayout::col}); 388 rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, 389 desiredRetTy, intrinsicResult, 390 rewriter)); 391 return success(); 392 } 393 }; 394 395 struct ConvertNVGPUToNVVMPass 396 : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> { 397 using Base::Base; 398 399 void getDependentDialects(DialectRegistry ®istry) const override { 400 registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, 401 arith::ArithDialect>(); 402 } 403 404 void runOnOperation() override { 405 LowerToLLVMOptions options(&getContext()); 406 RewritePatternSet patterns(&getContext()); 407 LLVMTypeConverter converter(&getContext(), options); 408 IRRewriter rewriter(&getContext()); 409 populateGpuMemorySpaceAttributeConversions( 410 converter, [](gpu::AddressSpace space) -> unsigned { 411 switch (space) { 412 case gpu::AddressSpace::Global: 413 return static_cast<unsigned>( 414 NVVM::NVVMMemorySpace::kGlobalMemorySpace); 415 case gpu::AddressSpace::Workgroup: 416 return static_cast<unsigned>( 417 NVVM::NVVMMemorySpace::kSharedMemorySpace); 418 case gpu::AddressSpace::Private: 419 return 0; 420 } 421 llvm_unreachable("unknown address space enum value"); 422 return 0; 423 }); 424 /// device-side async tokens cannot be materialized in nvvm. We just 425 /// convert them to a dummy i32 type in order to easily drop them during 426 /// conversion. 
427 converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { 428 return converter.convertType(IntegerType::get(type.getContext(), 32)); 429 }); 430 converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { 431 Type elemType = type.getFragmented().getElementType(); 432 int64_t sizeM = type.getFragmented().getDimSize(0); 433 int64_t sizeN = type.getFragmented().getDimSize(1); 434 435 unsigned numMembers; 436 if (elemType.isF32() || elemType.isInteger(32)) 437 numMembers = sizeN / 2; 438 else if (elemType.isF16()) 439 numMembers = sizeN / 4; 440 else 441 llvm_unreachable("unsupported type for warpgroup accumulator"); 442 443 SmallVector<Type> innerStructBody; 444 for (unsigned i = 0; i < numMembers; i++) 445 innerStructBody.push_back(elemType); 446 auto innerStructType = 447 LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); 448 449 SmallVector<Type> structBody; 450 for (int i = 0; i < sizeM; i += kWgmmaSizeM) 451 structBody.push_back(innerStructType); 452 453 auto convertedType = 454 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); 455 return converter.convertType(convertedType); 456 }); 457 converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { 458 return converter.convertType(IntegerType::get(type.getContext(), 64)); 459 }); 460 converter.addConversion( 461 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { 462 return converter.convertType(IntegerType::get(type.getContext(), 64)); 463 }); 464 converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { 465 return converter.convertType( 466 nvgpu::getMBarrierMemrefType(rewriter.getContext(), type)); 467 }); 468 converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type { 469 return LLVM::LLVMPointerType::get(type.getContext()); 470 }); 471 populateNVGPUToNVVMConversionPatterns(converter, patterns); 472 LLVMConversionTarget target(getContext()); 473 target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); 474 target.addLegalDialect<::mlir::arith::ArithDialect>(); 475 target.addLegalDialect<::mlir::memref::MemRefDialect>(); 476 target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); 477 mlir::scf::populateSCFStructuralTypeConversionsAndLegality( 478 converter, patterns, target); 479 if (failed(applyPartialConversion(getOperation(), target, 480 std::move(patterns)))) 481 signalPassFailure(); 482 } 483 }; 484 485 /// Returns the constraints for the sparse MMA inline assembly instruction. 486 static std::string buildMmaSparseAsmConstraintString(unsigned matASize, 487 unsigned matBSize, 488 unsigned matCSize) { 489 std::string str; 490 llvm::raw_string_ostream ss(str); 491 for (unsigned i = 0; i < matCSize; i++) 492 ss << "=r,"; 493 for (unsigned i = 0; i < matASize + matBSize + matCSize; i++) 494 ss << "r,"; 495 // The final operand is for the sparsity metadata. 496 // The sparsity selector appears as direct literal. 497 ss << "r"; 498 return str; 499 } 500 501 /// Returns the string for the `mma.sp.sync` instruction that corresponds to 502 /// the given parameters. Note that this function doesn't do any validation, 503 /// it's expected that the provided parameters correspond to a valid 504 /// instruction. 
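/// As a rough illustration only (register counts depend on the fragment sizes
/// the caller passes in), the produced string looks like:
///   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
///     {$0,...},{...},{...},{...},$<meta>,0x0;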
505 static std::string buildMmaSparseAsmString( 506 const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize, 507 unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, 508 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, 509 std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) { 510 auto ptxTypeStr = [](NVVM::MMATypes ptxType) { 511 return NVVM::stringifyMMATypes(ptxType); 512 }; 513 514 std::string asmStr; 515 llvm::raw_string_ostream ss(asmStr); 516 ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k" 517 << shape[2] << ".row.col."; 518 519 if (overflow) 520 ss << NVVM::stringifyMMAIntOverflow(*overflow) << "."; 521 522 ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "." 523 << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " "; 524 unsigned asmArgIdx = 0; 525 526 // The operand string is structured into sections `{matC elements...}, 527 // {matA elements...}, {matB elements...}, {matC elements}`. 528 for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) { 529 ss << "{"; 530 for (unsigned i = 0; i < arrSize; i++) 531 ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : ""); 532 ss << "},"; 533 } 534 ss << "$" << asmArgIdx++ << ","; 535 assert(metaDataSelector <= 1); 536 ss << "0x" << metaDataSelector << ";"; 537 return asmStr; 538 } 539 540 /// Builds an inline assembly operation corresponding to the specified MMA 541 /// sparse sync operation. 542 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm( 543 ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB, 544 NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD, 545 std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData, 546 ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData, 547 int64_t metadataSelector, const std::array<int64_t, 3> &shape, 548 Type intrinsicResultType) { 549 auto asmDialectAttr = 550 LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT); 551 552 const unsigned matASize = unpackedAData.size(); 553 const unsigned matBSize = unpackedB.size(); 554 const unsigned matCSize = unpackedC.size(); 555 556 std::string asmStr = buildMmaSparseAsmString( 557 shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC, 558 ptxTypeD, overflow, metadataSelector); 559 std::string constraintStr = 560 buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize); 561 562 SmallVector<Value> asmVals; 563 asmVals.reserve(matASize + matBSize + matCSize + 1); 564 for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC}) 565 llvm::append_range(asmVals, args); 566 asmVals.push_back(indexData); 567 568 return b.create<LLVM::InlineAsmOp>( 569 /*resultTypes=*/intrinsicResultType, 570 /*operands=*/asmVals, 571 /*asm_string=*/asmStr, 572 /*constraints=*/constraintStr, 573 /*has_side_effects=*/true, 574 /*is_align_stack=*/false, 575 /*asm_dialect=*/asmDialectAttr, 576 /*operand_attrs=*/ArrayAttr()); 577 } 578 579 /// Lowers `nvgpu.mma.sp.sync` to inline assembly. 580 struct NVGPUMmaSparseSyncLowering 581 : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> { 582 using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern; 583 584 LogicalResult 585 matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor, 586 ConversionPatternRewriter &rewriter) const override { 587 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 588 // Get the shapes of the MMAMatrix type being used. 
The shapes will 589 // choose which intrinsic this op will be lowered to. 590 VectorType aType = op.getMatrixA().getType(); 591 VectorType bType = op.getMatrixB().getType(); 592 VectorType cType = op.getMatrixC().getType(); 593 594 FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType); 595 if (failed(ptxTypeA)) 596 return op->emitOpError("failed to deduce operand PTX types"); 597 FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType); 598 if (failed(ptxTypeB)) 599 return op->emitOpError("failed to deduce operand PTX types"); 600 std::optional<NVVM::MMATypes> ptxTypeC = 601 NVVM::MmaOp::inferOperandMMAType(cType.getElementType(), 602 /*isAccumulator=*/true); 603 if (!ptxTypeC) 604 return op->emitError( 605 "could not infer the PTX type for the accumulator/result"); 606 607 // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32). 608 bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName()); 609 if (aType.getElementType().isF32() && !tf32Enabled) 610 return failure(); 611 612 // TODO: add an attribute to the op to customize this behavior. 613 std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt); 614 if (isa<IntegerType>(aType.getElementType())) 615 overflow = NVVM::MMAIntOverflow::satfinite; 616 617 SmallVector<Value> matA = 618 unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA); 619 SmallVector<Value> matB = 620 unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB); 621 SmallVector<Value> matC = 622 unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC); 623 624 Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); 625 Type intrinsicResTy = inferIntrinsicResultType( 626 typeConverter->convertType(op->getResultTypes()[0])); 627 628 // Bitcast the sparse metadata from vector<2xf16> to an i32. 629 Value sparseMetadata = adaptor.getSparseMetadata(); 630 if (sparseMetadata.getType() != 631 LLVM::getFixedVectorType(rewriter.getI16Type(), 2)) 632 return op->emitOpError() << "Expected metadata type to be LLVM " 633 "VectorType of 2 i16 elements"; 634 sparseMetadata = 635 b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata); 636 637 FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm( 638 b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB, 639 matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(), 640 intrinsicResTy); 641 if (failed(intrinsicResult)) 642 return failure(); 643 644 assert((*intrinsicResult).getNumResults() == 1 && 645 "expected inline asm op returns a single LLVM struct type"); 646 rewriter.replaceOp( 647 op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy, 648 (*intrinsicResult)->getResult(0), rewriter)); 649 return success(); 650 } 651 }; 652 653 struct NVGPUAsyncCopyLowering 654 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> { 655 using ConvertOpToLLVMPattern< 656 nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern; 657 658 LogicalResult 659 matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, 660 ConversionPatternRewriter &rewriter) const override { 661 ImplicitLocOpBuilder b(op.getLoc(), rewriter); 662 Location loc = op.getLoc(); 663 auto dstMemrefType = cast<MemRefType>(op.getDst().getType()); 664 Value dstPtr = 665 getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(), 666 adaptor.getDstIndices(), rewriter); 667 FailureOr<unsigned> dstAddressSpace = 668 getTypeConverter()->getMemRefAddressSpace(dstMemrefType); 669 if (failed(dstAddressSpace)) 670 return rewriter.notifyMatchFailure( 671 loc, "destination memref 
address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // The intrinsic takes a global pointer, so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    srcPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, srcPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CpAsyncOp reads bytes from the source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // When the optional SrcElements argument is present, the source (global
      // memory) of CpAsyncOp is read only for SrcElements number of elements.
      // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
    }
    // Cache global (.cg) for 16 dst bytes, cache all (.ca) for sizes other
    // than 16 dst bytes.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, srcPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};

struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
738 Value zero = rewriter.create<LLVM::ConstantOp>( 739 op->getLoc(), IntegerType::get(op.getContext(), 32), 740 rewriter.getI32IntegerAttr(0)); 741 rewriter.replaceOp(op, zero); 742 return success(); 743 } 744 }; 745 746 struct NVGPUAsyncWaitLowering 747 : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> { 748 using ConvertOpToLLVMPattern< 749 nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern; 750 751 LogicalResult 752 matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor, 753 ConversionPatternRewriter &rewriter) const override { 754 // If numGroup is not present pick 0 as a conservative correct value. 755 int32_t numGroups = adaptor.getNumGroups().value_or(0); 756 rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups); 757 rewriter.eraseOp(op); 758 return success(); 759 } 760 }; 761 762 /// Creates mbarrier object in shared memory 763 struct NVGPUMBarrierCreateLowering 764 : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> { 765 using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern; 766 767 template <typename moduleT> 768 memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter, 769 Operation *funcOp, moduleT moduleOp, 770 MemRefType barrierType) const { 771 SymbolTable symbolTable(moduleOp); 772 OpBuilder::InsertionGuard guard(rewriter); 773 rewriter.setInsertionPoint(&moduleOp.front()); 774 auto global = rewriter.create<memref::GlobalOp>( 775 funcOp->getLoc(), "__mbarrier", 776 /*sym_visibility=*/rewriter.getStringAttr("private"), 777 /*type=*/barrierType, 778 /*initial_value=*/ElementsAttr(), 779 /*constant=*/false, 780 /*alignment=*/rewriter.getI64IntegerAttr(8)); 781 symbolTable.insert(global); 782 return global; 783 } 784 785 LogicalResult 786 matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor, 787 ConversionPatternRewriter &rewriter) const override { 788 Operation *funcOp = op->getParentOp(); 789 MemRefType barrierType = nvgpu::getMBarrierMemrefType( 790 rewriter.getContext(), op.getBarriers().getType()); 791 792 memref::GlobalOp global; 793 if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) 794 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); 795 else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>()) 796 global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType); 797 798 rewriter.setInsertionPoint(op); 799 rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType, 800 global.getName()); 801 return success(); 802 } 803 }; 804 805 /// Base class for lowering mbarrier operations to nvvm intrinsics. 806 template <typename SourceOp> 807 struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> { 808 public: 809 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 810 /// Returns the base pointer of the mbarrier object. 
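  /// The address is computed with the usual strided-element arithmetic on the
  /// converted mbarrier memref, using `mbarId` as the index into the barrier
  /// group.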
811 Value getMbarrierPtr(ImplicitLocOpBuilder &b, 812 nvgpu::MBarrierGroupType mbarType, Value memrefDesc, 813 Value mbarId, 814 ConversionPatternRewriter &rewriter) const { 815 MemRefType mbarrierMemrefType = 816 nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType); 817 return ConvertToLLVMPattern::getStridedElementPtr( 818 b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter); 819 } 820 }; 821 822 /// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init` 823 struct NVGPUMBarrierInitLowering 824 : public MBarrierBasePattern<nvgpu::MBarrierInitOp> { 825 using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern; 826 827 LogicalResult 828 matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor, 829 ConversionPatternRewriter &rewriter) const override { 830 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 831 nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType(); 832 rewriter.setInsertionPoint(op); 833 Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), 834 adaptor.getMbarId(), rewriter); 835 Value count = truncToI32(b, adaptor.getCount()); 836 if (isMbarrierShared(mbarrierType)) { 837 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( 838 op, barrier, count, adaptor.getPredicate()); 839 } else { 840 rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, 841 adaptor.getPredicate()); 842 } 843 return success(); 844 } 845 }; 846 847 /// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive` 848 struct NVGPUMBarrierArriveLowering 849 : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> { 850 using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern; 851 LogicalResult 852 matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor, 853 ConversionPatternRewriter &rewriter) const override { 854 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 855 Value barrier = 856 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 857 adaptor.getMbarId(), rewriter); 858 Type tokenType = getTypeConverter()->convertType( 859 nvgpu::MBarrierTokenType::get(op->getContext())); 860 if (isMbarrierShared(op.getBarriers().getType())) { 861 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType, 862 barrier); 863 } else { 864 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, 865 barrier); 866 } 867 return success(); 868 } 869 }; 870 871 /// Lowers `nvgpu.mbarrier.arrive.nocomplete` to 872 /// `nvvm.mbarrier.arrive.nocomplete` 873 struct NVGPUMBarrierArriveNoCompleteLowering 874 : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> { 875 using MBarrierBasePattern< 876 nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern; 877 LogicalResult 878 matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor, 879 ConversionPatternRewriter &rewriter) const override { 880 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 881 Value barrier = 882 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 883 adaptor.getMbarId(), rewriter); 884 Type tokenType = getTypeConverter()->convertType( 885 nvgpu::MBarrierTokenType::get(op->getContext())); 886 Value count = truncToI32(b, adaptor.getCount()); 887 if (isMbarrierShared(op.getBarriers().getType())) { 888 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>( 889 op, tokenType, barrier, count); 890 } else { 891 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>( 892 op, tokenType, barrier, count); 893 } 894 return success(); 895 } 896 }; 897 898 /// Lowers 
`nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait` 899 struct NVGPUMBarrierTestWaitLowering 900 : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> { 901 using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern; 902 LogicalResult 903 matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor, 904 ConversionPatternRewriter &rewriter) const override { 905 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 906 Value barrier = 907 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 908 adaptor.getMbarId(), rewriter); 909 Type retType = rewriter.getI1Type(); 910 if (isMbarrierShared(op.getBarriers().getType())) { 911 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>( 912 op, retType, barrier, adaptor.getToken()); 913 } else { 914 rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>( 915 op, retType, barrier, adaptor.getToken()); 916 } 917 return success(); 918 } 919 }; 920 921 struct NVGPUMBarrierArriveExpectTxLowering 922 : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> { 923 using MBarrierBasePattern< 924 nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern; 925 LogicalResult 926 matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, 927 ConversionPatternRewriter &rewriter) const override { 928 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 929 Value barrier = 930 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 931 adaptor.getMbarId(), rewriter); 932 Value txcount = truncToI32(b, adaptor.getTxcount()); 933 934 if (isMbarrierShared(op.getBarriers().getType())) { 935 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>( 936 op, barrier, txcount, adaptor.getPredicate()); 937 return success(); 938 } 939 940 rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>( 941 op, barrier, txcount, adaptor.getPredicate()); 942 return success(); 943 } 944 }; 945 946 struct NVGPUMBarrierTryWaitParityLowering 947 : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> { 948 using MBarrierBasePattern< 949 nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern; 950 LogicalResult 951 matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, 952 ConversionPatternRewriter &rewriter) const override { 953 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 954 Value barrier = 955 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 956 adaptor.getMbarId(), rewriter); 957 Value ticks = truncToI32(b, adaptor.getTicks()); 958 Value phase = 959 b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity()); 960 961 if (isMbarrierShared(op.getBarriers().getType())) { 962 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>( 963 op, barrier, phase, ticks); 964 return success(); 965 } 966 967 rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier, 968 phase, ticks); 969 return success(); 970 } 971 }; 972 973 struct NVGPUTmaAsyncLoadOpLowering 974 : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> { 975 using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern; 976 LogicalResult 977 matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor, 978 ConversionPatternRewriter &rewriter) const override { 979 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 980 auto srcMemrefType = cast<MemRefType>(op.getDst().getType()); 981 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, 982 adaptor.getDst(), {}, rewriter); 983 Value barrier = 984 getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), 985 adaptor.getMbarId(), 
rewriter); 986 987 SmallVector<Value> coords = adaptor.getCoordinates(); 988 for (auto [index, value] : llvm::enumerate(coords)) { 989 coords[index] = truncToI32(b, value); 990 } 991 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>( 992 op, dest, adaptor.getTensorMapDescriptor(), coords, barrier, 993 ValueRange{}, adaptor.getMulticastMask(), Value{}, 994 adaptor.getPredicate()); 995 return success(); 996 } 997 }; 998 999 struct NVGPUTmaAsyncStoreOpLowering 1000 : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> { 1001 using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern; 1002 LogicalResult 1003 matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor, 1004 ConversionPatternRewriter &rewriter) const override { 1005 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1006 auto srcMemrefType = cast<MemRefType>(op.getSrc().getType()); 1007 Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType, 1008 adaptor.getSrc(), {}, rewriter); 1009 SmallVector<Value> coords = adaptor.getCoordinates(); 1010 for (auto [index, value] : llvm::enumerate(coords)) { 1011 coords[index] = truncToI32(b, value); 1012 } 1013 1014 rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>( 1015 op, adaptor.getTensorMapDescriptor(), dest, coords, 1016 adaptor.getPredicate()); 1017 return success(); 1018 } 1019 }; 1020 1021 struct NVGPUGenerateWarpgroupDescriptorLowering 1022 : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> { 1023 using ConvertOpToLLVMPattern< 1024 nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern; 1025 1026 LogicalResult 1027 matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor, 1028 ConversionPatternRewriter &rewriter) const override { 1029 1030 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1031 1032 nvgpu::TensorMapSwizzleKind swizzleKind = 1033 op.getTensorMap().getType().getSwizzle(); 1034 1035 unsigned layout = 1036 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128 1037 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64 1038 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32 1039 : 1; 1040 unsigned swizzle = 1041 (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1 1042 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2 1043 : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 
3 1044 : 0; 1045 1046 auto ti64 = b.getIntegerType(64); 1047 auto makeConst = [&](uint64_t index) -> Value { 1048 return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index)); 1049 }; 1050 auto shiftLeft = [&](Value value, unsigned shift) -> Value { 1051 return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift)); 1052 }; 1053 auto shiftRight = [&](Value value, unsigned shift) -> Value { 1054 return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift)); 1055 }; 1056 auto insertBit = [&](Value desc, Value val, int startBit) { 1057 return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit)); 1058 }; 1059 1060 int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); 1061 uint64_t strideDimVal = (layout << 3) >> exclude4LSB; 1062 uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; 1063 uint64_t offsetVal = 0; 1064 1065 Value strideDim = makeConst(strideDimVal); 1066 Value leadDim = makeConst(leadDimVal); 1067 1068 Value baseAddr = getStridedElementPtr( 1069 op->getLoc(), cast<MemRefType>(op.getTensor().getType()), 1070 adaptor.getTensor(), {}, rewriter); 1071 Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr); 1072 // Just use 14 bits for base address 1073 Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50); 1074 1075 int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32, 1076 startLeadBit = 16, startBaseAddrBit = 0; 1077 Value dsc = makeConst(0); 1078 // // [62,64) swizzle type 1079 dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit); 1080 // // [49,52) base_offset 1081 dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit); 1082 // // [32,46) stride 1083 dsc = insertBit(dsc, strideDim, startStrideBit); 1084 // // [16,30) leading dimension 1085 dsc = insertBit(dsc, leadDim, startLeadBit); 1086 // // [0,14) start_address 1087 dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); 1088 1089 LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " 1090 << "leading_off:" << leadDimVal << "\t" 1091 << "stride_off :" << strideDimVal << "\t" 1092 << "base_offset:" << offsetVal << "\t" 1093 << "layout_type:" << swizzle << " (" 1094 << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) 1095 << ")\n start_addr : " << baseAddr << "\n"); 1096 1097 rewriter.replaceOp(op, dsc); 1098 return success(); 1099 } 1100 }; 1101 1102 static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) { 1103 return b.create<LLVM::ConstantOp>(b.getIntegerType(64), 1104 b.getI32IntegerAttr(index)); 1105 } 1106 1107 /// Returns a Value that holds data type enum that is expected by CUDA driver. 
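/// For example, an f16 element type yields CU_TENSOR_MAP_DATA_TYPE_FLOAT16 and
/// f32 yields CU_TENSOR_MAP_DATA_TYPE_FLOAT32, each materialized as an i64
/// constant below.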
1108 static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) { 1109 // Enum is from CUDA driver API 1110 // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html 1111 enum CUtensorMapDataTypeEnum { 1112 CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0, 1113 CU_TENSOR_MAP_DATA_TYPE_UINT16, 1114 CU_TENSOR_MAP_DATA_TYPE_UINT32, 1115 CU_TENSOR_MAP_DATA_TYPE_INT32, 1116 CU_TENSOR_MAP_DATA_TYPE_UINT64, 1117 CU_TENSOR_MAP_DATA_TYPE_INT64, 1118 CU_TENSOR_MAP_DATA_TYPE_FLOAT16, 1119 CU_TENSOR_MAP_DATA_TYPE_FLOAT32, 1120 CU_TENSOR_MAP_DATA_TYPE_FLOAT64, 1121 CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 1122 CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ, 1123 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32, 1124 CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ 1125 }; 1126 1127 if (type.isUnsignedInteger(8)) 1128 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8); 1129 if (type.isUnsignedInteger(16)) 1130 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16); 1131 if (type.isUnsignedInteger(32)) 1132 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32); 1133 if (type.isUnsignedInteger(64)) 1134 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64); 1135 if (type.isSignlessInteger(32)) 1136 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32); 1137 if (type.isSignlessInteger(64)) 1138 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64); 1139 if (type.isF16()) 1140 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16); 1141 if (type.isF32()) 1142 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32); 1143 if (type.isF64()) 1144 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64); 1145 if (type.isBF16()) 1146 return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16); 1147 1148 llvm_unreachable("Not supported data type"); 1149 } 1150 1151 struct NVGPUTmaCreateDescriptorOpLowering 1152 : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> { 1153 using ConvertOpToLLVMPattern< 1154 nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern; 1155 LogicalResult 1156 matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor, 1157 ConversionPatternRewriter &rewriter) const override { 1158 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 1159 auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext()); 1160 Type llvmInt64Type = IntegerType::get(op->getContext(), 64); 1161 1162 Value tensorElementType = 1163 elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType()); 1164 auto promotedOperands = getTypeConverter()->promoteOperands( 1165 b.getLoc(), op->getOperands(), adaptor.getOperands(), b); 1166 1167 Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, 1168 makeI64Const(b, 5)); 1169 for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) { 1170 Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType, 1171 boxArrayPtr, makeI64Const(b, index)); 1172 b.create<LLVM::StoreOp>(value, gep); 1173 } 1174 1175 nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType(); 1176 // Set Arguments for the function call 1177 SmallVector<Value> arguments; 1178 arguments.push_back(promotedOperands[0]); // rank 1179 arguments.push_back(promotedOperands[1]); // descriptor 1180 arguments.push_back(tensorElementType); // data type 1181 arguments.push_back( 1182 makeI64Const(b, (int)desc.getInterleave())); // interleave 1183 arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle 1184 arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo 1185 arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob 1186 
arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};

struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// This is a helper class to generate required NVVM Ops for warp-group level
  /// matrix multiplication.
  /// When the given GEMM shape is larger than the shape of a single wgmma
  /// instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp ops,
  /// group them, and execute them asynchronously. The class also handles
  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
  /// create descriptors for each instruction.
  ///
  /// For example, this is the case when the GEMM shape is 128x128x128:
  ///
  ///    nvvm.wgmma.fence.aligned
  ///
  ///    nvvm.wgmma.mma.async descA, descB
  ///    iterate(descA, descB)
  ///    nvvm.wgmma.mma.async descA, descB
  ///    [6x times more]
  ///
  ///    nvvm.wgmma.group.sync.aligned
  ///    nvvm.wgmma.wait.group.sync [groupId]
  ///
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given Op
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for GEMM
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// The function returns the shape of the wgmma instruction that is defined
    /// in the PTX programming guide.
    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
                 inputElemType.isInteger(16)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("unsupported K shape");
      }
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    }

    /// Generates WGMMATypesAttr from MLIR Type
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ?
NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32; 1275 if (elemType.isF16()) 1276 return NVVM::WGMMATypes::f16; 1277 if (elemType.isBF16()) 1278 return NVVM::WGMMATypes::bf16; 1279 if (isa<Float8E4M3FNType>(elemType)) 1280 return NVVM::WGMMATypes::e4m3; 1281 if (isa<Float8E5M2Type>(elemType)) 1282 return NVVM::WGMMATypes::e5m2; 1283 if (elemType.isInteger(1)) 1284 return NVVM::WGMMATypes::b1; 1285 if (elemType.isInteger(8)) 1286 return NVVM::WGMMATypes::s8; 1287 if (elemType.isUnsignedInteger(8)) 1288 return NVVM::WGMMATypes::u8; 1289 if (elemType.isInteger(32)) 1290 return NVVM::WGMMATypes::s32; 1291 llvm_unreachable("unsupported type"); 1292 }; 1293 return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type)); 1294 } 1295 1296 /// Generates layout attribute for the input matrix for wgmma instruction 1297 NVVM::MMALayoutAttr 1298 generateWgmmaLayout(std::optional<bool> transpose) const { 1299 if (transpose.value_or(false)) 1300 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col); 1301 return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row); 1302 } 1303 1304 /// Generates shape attribute for wgmma instruction 1305 NVVM::MMAShapeAttr generateWgmmaShape() const { 1306 return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK); 1307 } 1308 1309 /// Generates scale attributes of output matrix for wgmma instruction 1310 NVVM::WGMMAScaleOutAttr generateScaleOut() const { 1311 return NVVM::WGMMAScaleOutAttr::get(op->getContext(), 1312 NVVM::WGMMAScaleOut::one); 1313 } 1314 /// Generates scale attributes of input matrix for wgmma instruction 1315 NVVM::WGMMAScaleInAttr generateScaleIn() const { 1316 return NVVM::WGMMAScaleInAttr::get(op->getContext(), 1317 NVVM::WGMMAScaleIn::one); 1318 } 1319 1320 /// Basic function to generate Add 1321 Value makeAdd(Value lhs, Value rhs) { 1322 return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs); 1323 }; 1324 1325 /// Moves the descriptor pointer of matrix-A for the next wgmma instruction. 1326 /// Currently, it only handles row-major. 1327 /// 1328 /// It moves the pointer like below for [128][64] size: 1329 /// +2 +4 +6 1330 /// ↓ ↓ ↓ 1331 /// descA ---> +--+--+--+--+ 1332 /// |->|->|->|->| 1333 /// | | | | | 1334 /// | | | | | 1335 /// | | | | | 1336 /// descA+512---> +-----------+ 1337 /// | | | | | 1338 /// | | | | | 1339 /// | | | | | 1340 /// | | | | | 1341 /// +-----------+ 1342 /// 1343 Value iterateDescriptorA(Value desc, int i, int j, int k) { 1344 MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor(); 1345 Type elemA = matrixTypeA.getElementType(); 1346 int byte = elemA.getIntOrFloatBitWidth() / 8; 1347 int tileShapeA = matrixTypeA.getDimSize(1); 1348 int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; 1349 incrementVal = incrementVal >> exclude4LSB; 1350 LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k 1351 << "] [wgmma descriptors] Descriptor A + " 1352 << incrementVal << " | \t "); 1353 if (!incrementVal) 1354 return desc; 1355 return makeAdd(desc, makeI64Const(b, incrementVal)); 1356 } 1357 1358 /// Moves the descriptor pointer of matrix-B for the next wgmma instruction. 1359 /// Currently, it only handles column-major. 
1360 /// 1361 /// It moves the pointer like below for [128][64] size: 1362 /// descB ---> +--+--+--+--+--+--+--+--+ 1363 /// |↓ | | | | | | | | 1364 /// |↓ | | | | | | | | 1365 /// |↓ | | | | | | | | 1366 /// |↓ | | | | | | | | 1367 /// +--+--+--+--+--+--+--+--+ 1368 /// 1369 Value iterateDescriptorB(Value desc, int i, int j, int k) { 1370 MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor(); 1371 Type elemB = matrixTypeB.getElementType(); 1372 int byte = elemB.getIntOrFloatBitWidth() / 8; 1373 int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; 1374 incrementVal = incrementVal >> exclude4LSB; 1375 LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); 1376 if (!incrementVal) 1377 return desc; 1378 return makeAdd(desc, makeI64Const(b, incrementVal)); 1379 } 1380 1381 /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix 1382 /// descriptors and arranges them based on induction variables: i, j, and k. 1383 Value generateWgmma(int i, int j, int k, Value matrixC) { 1384 LLVM_DEBUG(DBGS() << "\t wgmma." 1385 << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK 1386 << "(A[" << (iterationM * wgmmaM) << ":" 1387 << (iterationM * wgmmaM) + wgmmaM << "][" 1388 << (iterationK * wgmmaK) << ":" 1389 << (iterationK * wgmmaK + wgmmaK) << "] * " 1390 << " B[" << (iterationK * wgmmaK) << ":" 1391 << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" 1392 << wgmmaN << "])\n"); 1393 1394 Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); 1395 Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); 1396 1397 Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); 1398 NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); 1399 1400 Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); 1401 NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); 1402 1403 Type elemD = op.getMatrixC().getType().getFragmented().getElementType(); 1404 NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true); 1405 1406 NVVM::MMAShapeAttr shape = generateWgmmaShape(); 1407 NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); 1408 NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); 1409 NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); 1410 NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB()); 1411 1412 auto overflow = NVVM::MMAIntOverflowAttr::get( 1413 op->getContext(), NVVM::MMAIntOverflow::wrapped); 1414 1415 return b.create<NVVM::WgmmaMmaAsyncOp>( 1416 matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA, 1417 itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB, 1418 overflow); 1419 } 1420 1421 /// Generates multiple wgmma instructions to complete the given GEMM shape 1422 Value generateWgmmaGroup() { 1423 Value wgmmaResult = 1424 b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType()); 1425 1426 // Perform GEMM 1427 SmallVector<Value> wgmmaResults; 1428 for (int i = 0; i < iterationM; ++i) { 1429 Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i); 1430 for (int j = 0; j < iterationN; ++j) 1431 for (int k = 0; k < iterationK; ++k) 1432 matrixC = generateWgmma(i, j, k, matrixC); 1433 wgmmaResults.push_back(matrixC); 1434 } 1435 for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) { 1436 wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(), 1437 wgmmaResult, matrix, idx); 1438 } 1439 return wgmmaResult; 1440 } 1441 1442 public: 1443 WarpgroupGemm(nvgpu::WarpgroupMmaOp op, 
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the shape of the entire GEMM.
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape of a single wgmma instruction.
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iteration counts needed to cover the GEMM shape with the wgmma shape.
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
    }

    /// Generates WgmmaMmaAsync ops to complete the specified GEMM shape. It
    /// emits a fence (WgmmaFenceAlignedOp) before the instructions, and group
    /// synchronization (WgmmaGroupSyncAlignedOp) followed by a wait on the
    /// group (WgmmaWaitGroupSyncOp) after them.
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate wgmma instructions covering the entire GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers, packed into
  /// a struct.
  /// Here is what each thread (T) holds; each `d` is a numbered struct value.
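  /// For example, T0's d0-d1 pair lands in row 0, columns 0-1, and its d2-d3
  /// pair lands in row 8, columns 0-1 (see the layout below).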
  ///
  /// Threads in the warp-group (128 threads) and what they own in matrix D:
  ///   0-31   Warp-0 -> MatrixD[0:15 ][0:N]
  ///   32-63  Warp-1 -> MatrixD[16:31][0:N]
  ///   64-95  Warp-2 -> MatrixD[32:47][0:N]
  ///   96-127 Warp-3 -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  ///   +____________________________________________________________________+
  ///   |    0-1    |   2-3    |   4-5   |   6-7   |  8-9   | 10-11  |..|N-8,N-7 |
  /// 0 | T0:d0-d1  |T1:d0-d1  |T2:d0-d1 |T3:d0-d1 |T0:d4-d5|T1:d4-d5|..|T0:dX-dY|
  /// 1 | T4:d0-d1  |T5:d0-d1  |T6:d0-d1 |T7:d0-d1 |T4:d4-d5|T5:d4-d5|..|T4:dX-dY|
  /// ..| ..........|..........|.........|.........|........|........|..|........|
  /// 8 | T0:d2-d3  |T1:d2-d3  |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7|..|T0:dZ-dW|
  /// 9 | T4:d2-d3  |T5:d2-d3  |T6:d2-d3 |T7:d2-d3 |T4:d6-d7|T5:d6-d7|..|T4:dZ-dW|
  /// ..| ..........|..........|.........|.........|........|........|..|........|
  /// 15| T28:d2-d3 |T29:d2-d3 |T30:d2-d3|T31:d2-d3|........|........|..|........|
  /// 16| T32:d0-d1 |T33:d0-d1 |T34:d0-d1|T35:d0-d1|........|........|..|........|
  /// ..| ..........|..........|.........|.........|........|........|..|........|
  /// 32| T64:d0-d1 |T65:d0-d1 |T66:d0-d1|T67:d0-d1|........|........|..|........|
  /// ..| ..........|..........|.........|.........|........|........|..|........|
  /// 48| T96:d0-d1 |T97:d0-d1 |T98:d0-d1|T99:d0-d1|........|........|..|........|
  /// ..| ..........|..........|.........|.........|........|........|..|........|
  ///   +____________________________________________________________________+
  ///
  /// \param b: The ImplicitLocOpBuilder used to create the ops.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  ///                 matrix). Each thread holds its part as a struct of 64
  ///                 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: The offset within the memref where the registers will be
  ///                stored.
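  ///
  /// Note: for a 64-element struct, the loops below perform 2 x 16 iterations,
  /// each storing a pair of adjacent values, i.e. 32 paired stores covering
  /// all 64 registers of the thread.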
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };

    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of adjacent 32-bit registers owned per thread.
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp.
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matrixDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matrixDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue =
          b.create<LLVM::ExtractValueOp>(matrixDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

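/// Lowers nvgpu.warpgroup.mma.init.accumulator by materializing the converted
/// nested LLVM struct type of the accumulator and filling every element with a
/// zero constant.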
struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType =
        cast<LLVM::LLVMStructType>(packStructType.getBody().front())
            .getBody()
            .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the inner structs and set all of their values to zero.
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs back into a single struct.
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f32 to each element of the vector.
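    // convert1DVec scalarizes a 1-D LLVM vector: each element is extracted,
    // passed through the NVVM reciprocal-approximation intrinsic, and inserted
    // back into the result vector.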
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
        Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
        Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
        ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
} // namespace

void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,           // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,             // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,              // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,         // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}